thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5
   6use agent_rules::load_worktree_rules_file;
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use fs::Fs;
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
  19    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  20    LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  21    PaymentRequiredError, Role, StopReason, TokenUsage,
  22};
  23use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  24use project::{Project, Worktree};
  25use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
  26use schemars::JsonSchema;
  27use serde::{Deserialize, Serialize};
  28use settings::Settings;
  29use util::{ResultExt as _, TryFutureExt as _, post_inc};
  30use uuid::Uuid;
  31
  32use crate::context::{AssistantContext, ContextId, format_context_as_string};
  33use crate::thread_store::{
  34    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  35    SerializedToolUse,
  36};
  37use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  38
/// What a completion request built from a [`Thread`] is for.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A regular chat request (the non-summary case).
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
  45
  46#[derive(
  47    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  48)]
  49pub struct ThreadId(Arc<str>);
  50
  51impl ThreadId {
  52    pub fn new() -> Self {
  53        Self(Uuid::new_v4().to_string().into())
  54    }
  55}
  56
  57impl std::fmt::Display for ThreadId {
  58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  59        write!(f, "{}", self.0)
  60    }
  61}
  62
  63impl From<&str> for ThreadId {
  64    fn from(value: &str) -> Self {
  65        Self(value.into())
  66    }
  67}
  68
  69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  70pub struct MessageId(pub(crate) usize);
  71
  72impl MessageId {
  73    fn post_inc(&mut self) -> Self {
  74        Self(post_inc(&mut self.0))
  75    }
  76}
  77
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier assigned by the owning thread when the message is inserted.
    pub id: MessageId,
    /// Author of the message.
    pub role: Role,
    /// Ordered body pieces (plain text and thinking).
    pub segments: Vec<MessageSegment>,
    /// Formatted context attached to the message; empty when none was attached.
    pub context: String,
}
  86
  87impl Message {
  88    /// Returns whether the message contains any meaningful text that should be displayed
  89    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  90    pub fn should_display_content(&self) -> bool {
  91        self.segments.iter().all(|segment| segment.should_display())
  92    }
  93
  94    pub fn push_thinking(&mut self, text: &str) {
  95        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  96            segment.push_str(text);
  97        } else {
  98            self.segments
  99                .push(MessageSegment::Thinking(text.to_string()));
 100        }
 101    }
 102
 103    pub fn push_text(&mut self, text: &str) {
 104        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 105            segment.push_str(text);
 106        } else {
 107            self.segments.push(MessageSegment::Text(text.to_string()));
 108        }
 109    }
 110
 111    pub fn to_string(&self) -> String {
 112        let mut result = String::new();
 113
 114        if !self.context.is_empty() {
 115            result.push_str(&self.context);
 116        }
 117
 118        for segment in &self.segments {
 119            match segment {
 120                MessageSegment::Text(text) => result.push_str(text),
 121                MessageSegment::Thinking(text) => {
 122                    result.push_str("<think>");
 123                    result.push_str(text);
 124                    result.push_str("</think>");
 125                }
 126            }
 127        }
 128
 129        result
 130    }
 131}
 132
 133#[derive(Debug, Clone, PartialEq, Eq)]
 134pub enum MessageSegment {
 135    Text(String),
 136    Thinking(String),
 137}
 138
 139impl MessageSegment {
 140    pub fn text_mut(&mut self) -> &mut String {
 141        match self {
 142            Self::Text(text) => text,
 143            Self::Thinking(text) => text,
 144        }
 145    }
 146
 147    pub fn should_display(&self) -> bool {
 148        // We add USING_TOOL_MARKER when making a request that includes tool uses
 149        // without non-whitespace text around them, and this can cause the model
 150        // to mimic the pattern, so we consider those segments not displayable.
 151        match self {
 152            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 153            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 154        }
 155    }
 156}
 157
/// Serializable snapshot of a project's worktrees and unsaved buffers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per captured worktree.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes — NOTE(review): filled in by the
    /// snapshotting code outside this view; confirm.
    pub unsaved_buffer_paths: Vec<String>,
    /// Presumably the time the snapshot was taken — confirm.
    pub timestamp: DateTime<Utc>,
}

/// Snapshot of a single worktree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// The worktree's path, stored as a string.
    pub worktree_path: String,
    /// Git information; presumably `None` when the worktree has no repository — confirm.
    pub git_state: Option<GitState>,
}

/// Git information captured for a worktree; each field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
 178
/// A git checkpoint tied to the message it was captured for.
///
/// Restoring it through `Thread::restore_checkpoint` truncates the thread starting at
/// `message_id`.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// Positive or negative feedback, recorded for a whole thread or for single messages.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 190
 191pub enum LastRestoreCheckpoint {
 192    Pending {
 193        message_id: MessageId,
 194    },
 195    Error {
 196        message_id: MessageId,
 197        error: String,
 198    },
 199}
 200
 201impl LastRestoreCheckpoint {
 202    pub fn message_id(&self) -> MessageId {
 203        match self {
 204            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 205            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 206        }
 207    }
 208}
 209
/// Progress of a thread's detailed summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary exists (the default).
    #[default]
    NotGenerated,
    /// A detailed summary is being generated.
    ///
    /// NOTE(review): `message_id` is presumably the last message the summary will
    /// cover — confirm against the summarization code.
    Generating {
        message_id: MessageId,
    },
    /// A finished detailed summary. In this state
    /// `Thread::latest_detailed_summary_or_text` returns `text`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 222
 223#[derive(Default)]
 224pub struct TotalTokenUsage {
 225    pub total: usize,
 226    pub max: usize,
 227    pub ratio: TokenUsageRatio,
 228}
 229
 230#[derive(Default, PartialEq, Eq)]
 231pub enum TokenUsageRatio {
 232    #[default]
 233    Normal,
 234    Warning,
 235    Exceeded,
 236}
 237
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last modification time; refreshed by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// Short title. `None` until one has been generated; `set_summary` refuses to set it
    /// before then.
    summary: Option<SharedString>,
    /// NOTE(review): presumably the in-flight summary generation task; it is assigned
    /// outside this view.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    /// Messages in insertion order.
    messages: Vec<Message>,
    /// Id handed out by the next `insert_message` call.
    next_message_id: MessageId,
    /// Every context item attached to this thread, keyed by id.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Ids of the context items first attached with each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Set via `set_system_prompt_context`.
    system_prompt_context: Option<AssistantSystemPromptContext>,
    /// Restorable git checkpoints, keyed by the message they were captured for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    /// NOTE(review): maintained by completion code outside this view.
    completion_count: usize,
    /// In-flight completions; non-empty means the thread is generating.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Arc<ToolWorkingSet>,
    /// Bookkeeping for tool uses and their results.
    tool_use: ToolUseState,
    /// Tracks buffers the thread interacts with (e.g. those added as context).
    action_log: Entity<ActionLog>,
    /// Outcome of the latest checkpoint restore; `None` after a successful one.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint captured with the latest user message. It is kept or discarded by
    /// `finalize_pending_checkpoint` once generation is over.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot taken at creation (or loaded on deserialize). It is shared so
    /// it can be awaited more than once.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    cumulative_token_usage: TokenUsage,
    /// Feedback for the thread as a whole.
    feedback: Option<ThreadFeedback>,
    /// Feedback for individual messages.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
}
 265
 266impl Thread {
 267    pub fn new(
 268        project: Entity<Project>,
 269        tools: Arc<ToolWorkingSet>,
 270        prompt_builder: Arc<PromptBuilder>,
 271        cx: &mut Context<Self>,
 272    ) -> Self {
 273        Self {
 274            id: ThreadId::new(),
 275            updated_at: Utc::now(),
 276            summary: None,
 277            pending_summary: Task::ready(None),
 278            detailed_summary_state: DetailedSummaryState::NotGenerated,
 279            messages: Vec::new(),
 280            next_message_id: MessageId(0),
 281            context: BTreeMap::default(),
 282            context_by_message: HashMap::default(),
 283            system_prompt_context: None,
 284            checkpoints_by_message: HashMap::default(),
 285            completion_count: 0,
 286            pending_completions: Vec::new(),
 287            project: project.clone(),
 288            prompt_builder,
 289            tools: tools.clone(),
 290            last_restore_checkpoint: None,
 291            pending_checkpoint: None,
 292            tool_use: ToolUseState::new(tools.clone()),
 293            action_log: cx.new(|_| ActionLog::new(project.clone())),
 294            initial_project_snapshot: {
 295                let project_snapshot = Self::project_snapshot(project, cx);
 296                cx.foreground_executor()
 297                    .spawn(async move { Some(project_snapshot.await) })
 298                    .shared()
 299            },
 300            cumulative_token_usage: TokenUsage::default(),
 301            feedback: None,
 302            message_feedback: HashMap::default(),
 303        }
 304    }
 305
 306    pub fn deserialize(
 307        id: ThreadId,
 308        serialized: SerializedThread,
 309        project: Entity<Project>,
 310        tools: Arc<ToolWorkingSet>,
 311        prompt_builder: Arc<PromptBuilder>,
 312        cx: &mut Context<Self>,
 313    ) -> Self {
 314        let next_message_id = MessageId(
 315            serialized
 316                .messages
 317                .last()
 318                .map(|message| message.id.0 + 1)
 319                .unwrap_or(0),
 320        );
 321        let tool_use =
 322            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 323
 324        Self {
 325            id,
 326            updated_at: serialized.updated_at,
 327            summary: Some(serialized.summary),
 328            pending_summary: Task::ready(None),
 329            detailed_summary_state: serialized.detailed_summary_state,
 330            messages: serialized
 331                .messages
 332                .into_iter()
 333                .map(|message| Message {
 334                    id: message.id,
 335                    role: message.role,
 336                    segments: message
 337                        .segments
 338                        .into_iter()
 339                        .map(|segment| match segment {
 340                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 341                            SerializedMessageSegment::Thinking { text } => {
 342                                MessageSegment::Thinking(text)
 343                            }
 344                        })
 345                        .collect(),
 346                    context: message.context,
 347                })
 348                .collect(),
 349            next_message_id,
 350            context: BTreeMap::default(),
 351            context_by_message: HashMap::default(),
 352            system_prompt_context: None,
 353            checkpoints_by_message: HashMap::default(),
 354            completion_count: 0,
 355            pending_completions: Vec::new(),
 356            last_restore_checkpoint: None,
 357            pending_checkpoint: None,
 358            project: project.clone(),
 359            prompt_builder,
 360            tools,
 361            tool_use,
 362            action_log: cx.new(|_| ActionLog::new(project)),
 363            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 364            cumulative_token_usage: serialized.cumulative_token_usage,
 365            feedback: None,
 366            message_feedback: HashMap::default(),
 367        }
 368    }
 369
 370    pub fn id(&self) -> &ThreadId {
 371        &self.id
 372    }
 373
 374    pub fn is_empty(&self) -> bool {
 375        self.messages.is_empty()
 376    }
 377
 378    pub fn updated_at(&self) -> DateTime<Utc> {
 379        self.updated_at
 380    }
 381
 382    pub fn touch_updated_at(&mut self) {
 383        self.updated_at = Utc::now();
 384    }
 385
 386    pub fn summary(&self) -> Option<SharedString> {
 387        self.summary.clone()
 388    }
 389
 390    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 391
 392    pub fn summary_or_default(&self) -> SharedString {
 393        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 394    }
 395
 396    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 397        let Some(current_summary) = &self.summary else {
 398            // Don't allow setting summary until generated
 399            return;
 400        };
 401
 402        let mut new_summary = new_summary.into();
 403
 404        if new_summary.is_empty() {
 405            new_summary = Self::DEFAULT_SUMMARY;
 406        }
 407
 408        if current_summary != &new_summary {
 409            self.summary = Some(new_summary);
 410            cx.emit(ThreadEvent::SummaryChanged);
 411        }
 412    }
 413
 414    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 415        self.latest_detailed_summary()
 416            .unwrap_or_else(|| self.text().into())
 417    }
 418
 419    fn latest_detailed_summary(&self) -> Option<SharedString> {
 420        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 421            Some(text.clone())
 422        } else {
 423            None
 424        }
 425    }
 426
 427    pub fn message(&self, id: MessageId) -> Option<&Message> {
 428        self.messages.iter().find(|message| message.id == id)
 429    }
 430
 431    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 432        self.messages.iter()
 433    }
 434
 435    pub fn is_generating(&self) -> bool {
 436        !self.pending_completions.is_empty() || !self.all_tools_finished()
 437    }
 438
 439    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 440        &self.tools
 441    }
 442
 443    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 444        self.tool_use
 445            .pending_tool_uses()
 446            .into_iter()
 447            .find(|tool_use| &tool_use.id == id)
 448    }
 449
 450    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 451        self.tool_use
 452            .pending_tool_uses()
 453            .into_iter()
 454            .filter(|tool_use| tool_use.status.needs_confirmation())
 455    }
 456
 457    pub fn has_pending_tool_uses(&self) -> bool {
 458        !self.tool_use.pending_tool_uses().is_empty()
 459    }
 460
 461    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 462        self.checkpoints_by_message.get(&id).cloned()
 463    }
 464
 465    pub fn restore_checkpoint(
 466        &mut self,
 467        checkpoint: ThreadCheckpoint,
 468        cx: &mut Context<Self>,
 469    ) -> Task<Result<()>> {
 470        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 471            message_id: checkpoint.message_id,
 472        });
 473        cx.emit(ThreadEvent::CheckpointChanged);
 474        cx.notify();
 475
 476        let git_store = self.project().read(cx).git_store().clone();
 477        let restore = git_store.update(cx, |git_store, cx| {
 478            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 479        });
 480
 481        cx.spawn(async move |this, cx| {
 482            let result = restore.await;
 483            this.update(cx, |this, cx| {
 484                if let Err(err) = result.as_ref() {
 485                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 486                        message_id: checkpoint.message_id,
 487                        error: err.to_string(),
 488                    });
 489                } else {
 490                    this.truncate(checkpoint.message_id, cx);
 491                    this.last_restore_checkpoint = None;
 492                }
 493                this.pending_checkpoint = None;
 494                cx.emit(ThreadEvent::CheckpointChanged);
 495                cx.notify();
 496            })?;
 497            result
 498        })
 499    }
 500
    /// Decides what to do with the checkpoint captured alongside the latest user message.
    ///
    /// Does nothing while the thread is still generating (pending completions or
    /// unfinished tool uses) or when there is no pending checkpoint. Otherwise it takes
    /// a fresh "final" checkpoint and compares the two:
    /// - If they compare equal, the pending checkpoint is deleted from the git store.
    /// - If they differ, or the comparison fails, the pending checkpoint is kept via
    ///   `insert_checkpoint`, which makes it available from `checkpoint_for_message`.
    ///
    /// The final checkpoint is deleted afterwards either way. If taking the final
    /// checkpoint fails, the pending one is kept. Each `?` ends the task early if the
    /// corresponding entity update fails.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            // Too early: wait until generation and tool uses have finished.
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison counts as "different", so the checkpoint is kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    // Identical to the current state: the checkpoint is not worth keeping.
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The final checkpoint was only needed for the comparison.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // The current state could not be captured; keep the pending checkpoint.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 551
 552    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 553        self.checkpoints_by_message
 554            .insert(checkpoint.message_id, checkpoint);
 555        cx.emit(ThreadEvent::CheckpointChanged);
 556        cx.notify();
 557    }
 558
 559    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 560        self.last_restore_checkpoint.as_ref()
 561    }
 562
 563    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 564        let Some(message_ix) = self
 565            .messages
 566            .iter()
 567            .rposition(|message| message.id == message_id)
 568        else {
 569            return;
 570        };
 571        for deleted_message in self.messages.drain(message_ix..) {
 572            self.context_by_message.remove(&deleted_message.id);
 573            self.checkpoints_by_message.remove(&deleted_message.id);
 574        }
 575        cx.notify();
 576    }
 577
 578    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 579        self.context_by_message
 580            .get(&id)
 581            .into_iter()
 582            .flat_map(|context| {
 583                context
 584                    .iter()
 585                    .filter_map(|context_id| self.context.get(&context_id))
 586            })
 587    }
 588
 589    /// Returns whether all of the tool uses have finished running.
 590    pub fn all_tools_finished(&self) -> bool {
 591        // If the only pending tool uses left are the ones with errors, then
 592        // that means that we've finished running all of the pending tools.
 593        self.tool_use
 594            .pending_tool_uses()
 595            .iter()
 596            .all(|tool_use| tool_use.status.is_error())
 597    }
 598
 599    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 600        self.tool_use.tool_uses_for_message(id, cx)
 601    }
 602
 603    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 604        self.tool_use.tool_results_for_message(id)
 605    }
 606
 607    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 608        self.tool_use.tool_result(id)
 609    }
 610
 611    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 612        self.tool_use.message_has_tool_results(message_id)
 613    }
 614
 615    pub fn insert_user_message(
 616        &mut self,
 617        text: impl Into<String>,
 618        context: Vec<AssistantContext>,
 619        git_checkpoint: Option<GitStoreCheckpoint>,
 620        cx: &mut Context<Self>,
 621    ) -> MessageId {
 622        let text = text.into();
 623
 624        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 625
 626        // Filter out contexts that have already been included in previous messages
 627        let new_context: Vec<_> = context
 628            .into_iter()
 629            .filter(|ctx| !self.context.contains_key(&ctx.id()))
 630            .collect();
 631
 632        if !new_context.is_empty() {
 633            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 634                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 635                    message.context = context_string;
 636                }
 637            }
 638
 639            self.action_log.update(cx, |log, cx| {
 640                // Track all buffers added as context
 641                for ctx in &new_context {
 642                    match ctx {
 643                        AssistantContext::File(file_ctx) => {
 644                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 645                        }
 646                        AssistantContext::Directory(dir_ctx) => {
 647                            for context_buffer in &dir_ctx.context_buffers {
 648                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 649                            }
 650                        }
 651                        AssistantContext::Symbol(symbol_ctx) => {
 652                            log.buffer_added_as_context(
 653                                symbol_ctx.context_symbol.buffer.clone(),
 654                                cx,
 655                            );
 656                        }
 657                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 658                    }
 659                }
 660            });
 661        }
 662
 663        let context_ids = new_context
 664            .iter()
 665            .map(|context| context.id())
 666            .collect::<Vec<_>>();
 667        self.context.extend(
 668            new_context
 669                .into_iter()
 670                .map(|context| (context.id(), context)),
 671        );
 672        self.context_by_message.insert(message_id, context_ids);
 673
 674        if let Some(git_checkpoint) = git_checkpoint {
 675            self.pending_checkpoint = Some(ThreadCheckpoint {
 676                message_id,
 677                git_checkpoint,
 678            });
 679        }
 680        message_id
 681    }
 682
 683    pub fn insert_message(
 684        &mut self,
 685        role: Role,
 686        segments: Vec<MessageSegment>,
 687        cx: &mut Context<Self>,
 688    ) -> MessageId {
 689        let id = self.next_message_id.post_inc();
 690        self.messages.push(Message {
 691            id,
 692            role,
 693            segments,
 694            context: String::new(),
 695        });
 696        self.touch_updated_at();
 697        cx.emit(ThreadEvent::MessageAdded(id));
 698        id
 699    }
 700
 701    pub fn edit_message(
 702        &mut self,
 703        id: MessageId,
 704        new_role: Role,
 705        new_segments: Vec<MessageSegment>,
 706        cx: &mut Context<Self>,
 707    ) -> bool {
 708        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 709            return false;
 710        };
 711        message.role = new_role;
 712        message.segments = new_segments;
 713        self.touch_updated_at();
 714        cx.emit(ThreadEvent::MessageEdited(id));
 715        true
 716    }
 717
 718    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 719        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 720            return false;
 721        };
 722        self.messages.remove(index);
 723        self.context_by_message.remove(&id);
 724        self.touch_updated_at();
 725        cx.emit(ThreadEvent::MessageDeleted(id));
 726        true
 727    }
 728
 729    /// Returns the representation of this [`Thread`] in a textual form.
 730    ///
 731    /// This is the representation we use when attaching a thread as context to another thread.
 732    pub fn text(&self) -> String {
 733        let mut text = String::new();
 734
 735        for message in &self.messages {
 736            text.push_str(match message.role {
 737                language_model::Role::User => "User:",
 738                language_model::Role::Assistant => "Assistant:",
 739                language_model::Role::System => "System:",
 740            });
 741            text.push('\n');
 742
 743            for segment in &message.segments {
 744                match segment {
 745                    MessageSegment::Text(content) => text.push_str(content),
 746                    MessageSegment::Thinking(content) => {
 747                        text.push_str(&format!("<think>{}</think>", content))
 748                    }
 749                }
 750            }
 751            text.push('\n');
 752        }
 753
 754        text
 755    }
 756
 757    /// Serializes this thread into a format for storage or telemetry.
 758    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 759        let initial_project_snapshot = self.initial_project_snapshot.clone();
 760        cx.spawn(async move |this, cx| {
 761            let initial_project_snapshot = initial_project_snapshot.await;
 762            this.read_with(cx, |this, cx| SerializedThread {
 763                version: SerializedThread::VERSION.to_string(),
 764                summary: this.summary_or_default(),
 765                updated_at: this.updated_at(),
 766                messages: this
 767                    .messages()
 768                    .map(|message| SerializedMessage {
 769                        id: message.id,
 770                        role: message.role,
 771                        segments: message
 772                            .segments
 773                            .iter()
 774                            .map(|segment| match segment {
 775                                MessageSegment::Text(text) => {
 776                                    SerializedMessageSegment::Text { text: text.clone() }
 777                                }
 778                                MessageSegment::Thinking(text) => {
 779                                    SerializedMessageSegment::Thinking { text: text.clone() }
 780                                }
 781                            })
 782                            .collect(),
 783                        tool_uses: this
 784                            .tool_uses_for_message(message.id, cx)
 785                            .into_iter()
 786                            .map(|tool_use| SerializedToolUse {
 787                                id: tool_use.id,
 788                                name: tool_use.name,
 789                                input: tool_use.input,
 790                            })
 791                            .collect(),
 792                        tool_results: this
 793                            .tool_results_for_message(message.id)
 794                            .into_iter()
 795                            .map(|tool_result| SerializedToolResult {
 796                                tool_use_id: tool_result.tool_use_id.clone(),
 797                                is_error: tool_result.is_error,
 798                                content: tool_result.content.clone(),
 799                            })
 800                            .collect(),
 801                        context: message.context.clone(),
 802                    })
 803                    .collect(),
 804                initial_project_snapshot,
 805                cumulative_token_usage: this.cumulative_token_usage.clone(),
 806                detailed_summary_state: this.detailed_summary_state.clone(),
 807            })
 808        })
 809    }
 810
 811    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
 812        self.system_prompt_context = Some(context);
 813    }
 814
 815    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
 816        &self.system_prompt_context
 817    }
 818
 819    pub fn load_system_prompt_context(
 820        &self,
 821        cx: &App,
 822    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
 823        let project = self.project.read(cx);
 824        let tasks = project
 825            .visible_worktrees(cx)
 826            .map(|worktree| {
 827                Self::load_worktree_info_for_system_prompt(
 828                    project.fs().clone(),
 829                    worktree.read(cx),
 830                    cx,
 831                )
 832            })
 833            .collect::<Vec<_>>();
 834
 835        cx.spawn(async |_cx| {
 836            let results = futures::future::join_all(tasks).await;
 837            let mut first_err = None;
 838            let worktrees = results
 839                .into_iter()
 840                .map(|(worktree, err)| {
 841                    if first_err.is_none() && err.is_some() {
 842                        first_err = err;
 843                    }
 844                    worktree
 845                })
 846                .collect::<Vec<_>>();
 847            (AssistantSystemPromptContext::new(worktrees), first_err)
 848        })
 849    }
 850
 851    fn load_worktree_info_for_system_prompt(
 852        fs: Arc<dyn Fs>,
 853        worktree: &Worktree,
 854        cx: &App,
 855    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
 856        let root_name = worktree.root_name().into();
 857        let abs_path = worktree.abs_path();
 858
 859        let rules_task = load_worktree_rules_file(fs, worktree, cx);
 860        let Some(rules_task) = rules_task else {
 861            return Task::ready((
 862                WorktreeInfoForSystemPrompt {
 863                    root_name,
 864                    abs_path,
 865                    rules_file: None,
 866                },
 867                None,
 868            ));
 869        };
 870
 871        cx.spawn(async move |_| {
 872            let (rules_file, rules_file_error) = match rules_task.await {
 873                Ok(rules_file) => (Some(rules_file), None),
 874                Err(err) => (
 875                    None,
 876                    Some(ThreadError::Message {
 877                        header: "Error loading rules file".into(),
 878                        message: format!("{err}").into(),
 879                    }),
 880                ),
 881            };
 882            let worktree_info = WorktreeInfoForSystemPrompt {
 883                root_name,
 884                abs_path,
 885                rules_file,
 886            };
 887            (worktree_info, rules_file_error)
 888        })
 889    }
 890
 891    pub fn send_to_model(
 892        &mut self,
 893        model: Arc<dyn LanguageModel>,
 894        request_kind: RequestKind,
 895        cx: &mut Context<Self>,
 896    ) {
 897        let mut request = self.to_completion_request(request_kind, cx);
 898        if model.supports_tools() {
 899            request.tools = {
 900                let mut tools = Vec::new();
 901                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 902                    LanguageModelRequestTool {
 903                        name: tool.name(),
 904                        description: tool.description(),
 905                        input_schema: tool.input_schema(model.tool_input_format()),
 906                    }
 907                }));
 908
 909                tools
 910            };
 911        }
 912
 913        self.stream_completion(request, model, cx);
 914    }
 915
 916    pub fn used_tools_since_last_user_message(&self) -> bool {
 917        for message in self.messages.iter().rev() {
 918            if self.tool_use.message_has_tool_results(message.id) {
 919                return true;
 920            } else if message.role == Role::User {
 921                return false;
 922            }
 923        }
 924
 925        false
 926    }
 927
 928    pub fn to_completion_request(
 929        &self,
 930        request_kind: RequestKind,
 931        cx: &App,
 932    ) -> LanguageModelRequest {
 933        let mut request = LanguageModelRequest {
 934            messages: vec![],
 935            tools: Vec::new(),
 936            stop: Vec::new(),
 937            temperature: None,
 938        };
 939
 940        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
 941            if let Some(system_prompt) = self
 942                .prompt_builder
 943                .generate_assistant_system_prompt(system_prompt_context)
 944                .context("failed to generate assistant system prompt")
 945                .log_err()
 946            {
 947                request.messages.push(LanguageModelRequestMessage {
 948                    role: Role::System,
 949                    content: vec![MessageContent::Text(system_prompt)],
 950                    cache: true,
 951                });
 952            }
 953        } else {
 954            log::error!("system_prompt_context not set.")
 955        }
 956
 957        for message in &self.messages {
 958            let mut request_message = LanguageModelRequestMessage {
 959                role: message.role,
 960                content: Vec::new(),
 961                cache: false,
 962            };
 963
 964            match request_kind {
 965                RequestKind::Chat => {
 966                    self.tool_use
 967                        .attach_tool_results(message.id, &mut request_message);
 968                }
 969                RequestKind::Summarize => {
 970                    // We don't care about tool use during summarization.
 971                    if self.tool_use.message_has_tool_results(message.id) {
 972                        continue;
 973                    }
 974                }
 975            }
 976
 977            if !message.segments.is_empty() {
 978                request_message
 979                    .content
 980                    .push(MessageContent::Text(message.to_string()));
 981            }
 982
 983            match request_kind {
 984                RequestKind::Chat => {
 985                    self.tool_use
 986                        .attach_tool_uses(message.id, &mut request_message);
 987                }
 988                RequestKind::Summarize => {
 989                    // We don't care about tool use during summarization.
 990                }
 991            };
 992
 993            request.messages.push(request_message);
 994        }
 995
 996        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 997        if let Some(last) = request.messages.last_mut() {
 998            last.cache = true;
 999        }
1000
1001        self.attached_tracked_files_state(&mut request.messages, cx);
1002
1003        // Add reminder to the last user message about code blocks
1004        if let Some(last_user_message) = request
1005            .messages
1006            .iter_mut()
1007            .rev()
1008            .find(|msg| msg.role == Role::User)
1009        {
1010            last_user_message
1011                .content
1012                .push(MessageContent::Text(system_prompt_reminder(
1013                    &self.prompt_builder,
1014                )));
1015        }
1016
1017        request
1018    }
1019
1020    fn attached_tracked_files_state(
1021        &self,
1022        messages: &mut Vec<LanguageModelRequestMessage>,
1023        cx: &App,
1024    ) {
1025        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1026
1027        let mut stale_message = String::new();
1028
1029        let action_log = self.action_log.read(cx);
1030
1031        for stale_file in action_log.stale_buffers(cx) {
1032            let Some(file) = stale_file.read(cx).file() else {
1033                continue;
1034            };
1035
1036            if stale_message.is_empty() {
1037                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1038            }
1039
1040            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1041        }
1042
1043        let mut content = Vec::with_capacity(2);
1044
1045        if !stale_message.is_empty() {
1046            content.push(stale_message.into());
1047        }
1048
1049        if action_log.has_edited_files_since_project_diagnostics_check() {
1050            content.push(
1051                "\n\nWhen you're done making changes, make sure to check project diagnostics \
1052                and fix all errors AND warnings you introduced! \
1053                DO NOT mention you're going to do this until you're done."
1054                    .into(),
1055            );
1056        }
1057
1058        if !content.is_empty() {
1059            let context_message = LanguageModelRequestMessage {
1060                role: Role::User,
1061                content,
1062                cache: false,
1063            };
1064
1065            messages.push(context_message);
1066        }
1067    }
1068
1069    pub fn stream_completion(
1070        &mut self,
1071        request: LanguageModelRequest,
1072        model: Arc<dyn LanguageModel>,
1073        cx: &mut Context<Self>,
1074    ) {
1075        let pending_completion_id = post_inc(&mut self.completion_count);
1076
1077        let task = cx.spawn(async move |thread, cx| {
1078            let stream = model.stream_completion(request, &cx);
1079            let initial_token_usage =
1080                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
1081            let stream_completion = async {
1082                let mut events = stream.await?;
1083                let mut stop_reason = StopReason::EndTurn;
1084                let mut current_token_usage = TokenUsage::default();
1085
1086                while let Some(event) = events.next().await {
1087                    let event = event?;
1088
1089                    thread.update(cx, |thread, cx| {
1090                        match event {
1091                            LanguageModelCompletionEvent::StartMessage { .. } => {
1092                                thread.insert_message(
1093                                    Role::Assistant,
1094                                    vec![MessageSegment::Text(String::new())],
1095                                    cx,
1096                                );
1097                            }
1098                            LanguageModelCompletionEvent::Stop(reason) => {
1099                                stop_reason = reason;
1100                            }
1101                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1102                                thread.cumulative_token_usage =
1103                                    thread.cumulative_token_usage.clone() + token_usage.clone()
1104                                        - current_token_usage.clone();
1105                                current_token_usage = token_usage;
1106                            }
1107                            LanguageModelCompletionEvent::Text(chunk) => {
1108                                if let Some(last_message) = thread.messages.last_mut() {
1109                                    if last_message.role == Role::Assistant {
1110                                        last_message.push_text(&chunk);
1111                                        cx.emit(ThreadEvent::StreamedAssistantText(
1112                                            last_message.id,
1113                                            chunk,
1114                                        ));
1115                                    } else {
1116                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1117                                        // of a new Assistant response.
1118                                        //
1119                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1120                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1121                                        thread.insert_message(
1122                                            Role::Assistant,
1123                                            vec![MessageSegment::Text(chunk.to_string())],
1124                                            cx,
1125                                        );
1126                                    };
1127                                }
1128                            }
1129                            LanguageModelCompletionEvent::Thinking(chunk) => {
1130                                if let Some(last_message) = thread.messages.last_mut() {
1131                                    if last_message.role == Role::Assistant {
1132                                        last_message.push_thinking(&chunk);
1133                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1134                                            last_message.id,
1135                                            chunk,
1136                                        ));
1137                                    } else {
1138                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1139                                        // of a new Assistant response.
1140                                        //
1141                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1142                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1143                                        thread.insert_message(
1144                                            Role::Assistant,
1145                                            vec![MessageSegment::Thinking(chunk.to_string())],
1146                                            cx,
1147                                        );
1148                                    };
1149                                }
1150                            }
1151                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1152                                let last_assistant_message_id = thread
1153                                    .messages
1154                                    .iter_mut()
1155                                    .rfind(|message| message.role == Role::Assistant)
1156                                    .map(|message| message.id)
1157                                    .unwrap_or_else(|| {
1158                                        thread.insert_message(Role::Assistant, vec![], cx)
1159                                    });
1160
1161                                thread.tool_use.request_tool_use(
1162                                    last_assistant_message_id,
1163                                    tool_use,
1164                                    cx,
1165                                );
1166                            }
1167                        }
1168
1169                        thread.touch_updated_at();
1170                        cx.emit(ThreadEvent::StreamedCompletion);
1171                        cx.notify();
1172                    })?;
1173
1174                    smol::future::yield_now().await;
1175                }
1176
1177                thread.update(cx, |thread, cx| {
1178                    thread
1179                        .pending_completions
1180                        .retain(|completion| completion.id != pending_completion_id);
1181
1182                    if thread.summary.is_none() && thread.messages.len() >= 2 {
1183                        thread.summarize(cx);
1184                    }
1185                })?;
1186
1187                anyhow::Ok(stop_reason)
1188            };
1189
1190            let result = stream_completion.await;
1191
1192            thread
1193                .update(cx, |thread, cx| {
1194                    thread.finalize_pending_checkpoint(cx);
1195                    match result.as_ref() {
1196                        Ok(stop_reason) => match stop_reason {
1197                            StopReason::ToolUse => {
1198                                cx.emit(ThreadEvent::UsePendingTools);
1199                            }
1200                            StopReason::EndTurn => {}
1201                            StopReason::MaxTokens => {}
1202                        },
1203                        Err(error) => {
1204                            if error.is::<PaymentRequiredError>() {
1205                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1206                            } else if error.is::<MaxMonthlySpendReachedError>() {
1207                                cx.emit(ThreadEvent::ShowError(
1208                                    ThreadError::MaxMonthlySpendReached,
1209                                ));
1210                            } else {
1211                                let error_message = error
1212                                    .chain()
1213                                    .map(|err| err.to_string())
1214                                    .collect::<Vec<_>>()
1215                                    .join("\n");
1216                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1217                                    header: "Error interacting with language model".into(),
1218                                    message: SharedString::from(error_message.clone()),
1219                                }));
1220                            }
1221
1222                            thread.cancel_last_completion(cx);
1223                        }
1224                    }
1225                    cx.emit(ThreadEvent::DoneStreaming);
1226
1227                    if let Ok(initial_usage) = initial_token_usage {
1228                        let usage = thread.cumulative_token_usage.clone() - initial_usage;
1229
1230                        telemetry::event!(
1231                            "Assistant Thread Completion",
1232                            thread_id = thread.id().to_string(),
1233                            model = model.telemetry_id(),
1234                            model_provider = model.provider_id().to_string(),
1235                            input_tokens = usage.input_tokens,
1236                            output_tokens = usage.output_tokens,
1237                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1238                            cache_read_input_tokens = usage.cache_read_input_tokens,
1239                        );
1240                    }
1241                })
1242                .ok();
1243        });
1244
1245        self.pending_completions.push(PendingCompletion {
1246            id: pending_completion_id,
1247            _task: task,
1248        });
1249    }
1250
1251    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1252        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1253            return;
1254        };
1255
1256        if !model.provider.is_authenticated(cx) {
1257            return;
1258        }
1259
1260        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1261        request.messages.push(LanguageModelRequestMessage {
1262            role: Role::User,
1263            content: vec![
1264                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1265                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1266                 If the conversation is about a specific subject, include it in the title. \
1267                 Be descriptive. DO NOT speak in the first person."
1268                    .into(),
1269            ],
1270            cache: false,
1271        });
1272
1273        self.pending_summary = cx.spawn(async move |this, cx| {
1274            async move {
1275                let stream = model.model.stream_completion_text(request, &cx);
1276                let mut messages = stream.await?;
1277
1278                let mut new_summary = String::new();
1279                while let Some(message) = messages.stream.next().await {
1280                    let text = message?;
1281                    let mut lines = text.lines();
1282                    new_summary.extend(lines.next());
1283
1284                    // Stop if the LLM generated multiple lines.
1285                    if lines.next().is_some() {
1286                        break;
1287                    }
1288                }
1289
1290                this.update(cx, |this, cx| {
1291                    if !new_summary.is_empty() {
1292                        this.summary = Some(new_summary.into());
1293                    }
1294
1295                    cx.emit(ThreadEvent::SummaryGenerated);
1296                })?;
1297
1298                anyhow::Ok(())
1299            }
1300            .log_err()
1301            .await
1302        });
1303    }
1304
1305    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1306        let last_message_id = self.messages.last().map(|message| message.id)?;
1307
1308        match &self.detailed_summary_state {
1309            DetailedSummaryState::Generating { message_id, .. }
1310            | DetailedSummaryState::Generated { message_id, .. }
1311                if *message_id == last_message_id =>
1312            {
1313                // Already up-to-date
1314                return None;
1315            }
1316            _ => {}
1317        }
1318
1319        let ConfiguredModel { model, provider } =
1320            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1321
1322        if !provider.is_authenticated(cx) {
1323            return None;
1324        }
1325
1326        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1327
1328        request.messages.push(LanguageModelRequestMessage {
1329            role: Role::User,
1330            content: vec![
1331                "Generate a detailed summary of this conversation. Include:\n\
1332                1. A brief overview of what was discussed\n\
1333                2. Key facts or information discovered\n\
1334                3. Outcomes or conclusions reached\n\
1335                4. Any action items or next steps if any\n\
1336                Format it in Markdown with headings and bullet points."
1337                    .into(),
1338            ],
1339            cache: false,
1340        });
1341
1342        let task = cx.spawn(async move |thread, cx| {
1343            let stream = model.stream_completion_text(request, &cx);
1344            let Some(mut messages) = stream.await.log_err() else {
1345                thread
1346                    .update(cx, |this, _cx| {
1347                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1348                    })
1349                    .log_err();
1350
1351                return;
1352            };
1353
1354            let mut new_detailed_summary = String::new();
1355
1356            while let Some(chunk) = messages.stream.next().await {
1357                if let Some(chunk) = chunk.log_err() {
1358                    new_detailed_summary.push_str(&chunk);
1359                }
1360            }
1361
1362            thread
1363                .update(cx, |this, _cx| {
1364                    this.detailed_summary_state = DetailedSummaryState::Generated {
1365                        text: new_detailed_summary.into(),
1366                        message_id: last_message_id,
1367                    };
1368                })
1369                .log_err();
1370        });
1371
1372        self.detailed_summary_state = DetailedSummaryState::Generating {
1373            message_id: last_message_id,
1374        };
1375
1376        Some(task)
1377    }
1378
1379    pub fn is_generating_detailed_summary(&self) -> bool {
1380        matches!(
1381            self.detailed_summary_state,
1382            DetailedSummaryState::Generating { .. }
1383        )
1384    }
1385
1386    pub fn use_pending_tools(
1387        &mut self,
1388        cx: &mut Context<Self>,
1389    ) -> impl IntoIterator<Item = PendingToolUse> + use<> {
1390        let request = self.to_completion_request(RequestKind::Chat, cx);
1391        let messages = Arc::new(request.messages);
1392        let pending_tool_uses = self
1393            .tool_use
1394            .pending_tool_uses()
1395            .into_iter()
1396            .filter(|tool_use| tool_use.status.is_idle())
1397            .cloned()
1398            .collect::<Vec<_>>();
1399
1400        for tool_use in pending_tool_uses.iter() {
1401            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1402                if tool.needs_confirmation(&tool_use.input, cx)
1403                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1404                {
1405                    self.tool_use.confirm_tool_use(
1406                        tool_use.id.clone(),
1407                        tool_use.ui_text.clone(),
1408                        tool_use.input.clone(),
1409                        messages.clone(),
1410                        tool,
1411                    );
1412                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1413                } else {
1414                    self.run_tool(
1415                        tool_use.id.clone(),
1416                        tool_use.ui_text.clone(),
1417                        tool_use.input.clone(),
1418                        &messages,
1419                        tool,
1420                        cx,
1421                    );
1422                }
1423            }
1424        }
1425
1426        pending_tool_uses
1427    }
1428
1429    pub fn run_tool(
1430        &mut self,
1431        tool_use_id: LanguageModelToolUseId,
1432        ui_text: impl Into<SharedString>,
1433        input: serde_json::Value,
1434        messages: &[LanguageModelRequestMessage],
1435        tool: Arc<dyn Tool>,
1436        cx: &mut Context<Thread>,
1437    ) {
1438        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1439        self.tool_use
1440            .run_pending_tool(tool_use_id, ui_text.into(), task);
1441    }
1442
1443    fn spawn_tool_use(
1444        &mut self,
1445        tool_use_id: LanguageModelToolUseId,
1446        messages: &[LanguageModelRequestMessage],
1447        input: serde_json::Value,
1448        tool: Arc<dyn Tool>,
1449        cx: &mut Context<Thread>,
1450    ) -> Task<()> {
1451        let tool_name: Arc<str> = tool.name().into();
1452
1453        let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1454            Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1455        } else {
1456            tool.run(
1457                input,
1458                messages,
1459                self.project.clone(),
1460                self.action_log.clone(),
1461                cx,
1462            )
1463        };
1464
1465        cx.spawn({
1466            async move |thread: WeakEntity<Thread>, cx| {
1467                let output = run_tool.await;
1468
1469                thread
1470                    .update(cx, |thread, cx| {
1471                        let pending_tool_use = thread.tool_use.insert_tool_output(
1472                            tool_use_id.clone(),
1473                            tool_name,
1474                            output,
1475                            cx,
1476                        );
1477
1478                        cx.emit(ThreadEvent::ToolFinished {
1479                            tool_use_id,
1480                            pending_tool_use,
1481                            canceled: false,
1482                        });
1483                    })
1484                    .ok();
1485            }
1486        })
1487    }
1488
1489    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1490        // Insert a user message to contain the tool results.
1491        self.insert_user_message(
1492            // TODO: Sending up a user message without any content results in the model sending back
1493            // responses that also don't have any content. We currently don't handle this case well,
1494            // so for now we provide some text to keep the model on track.
1495            "Here are the tool results.",
1496            Vec::new(),
1497            None,
1498            cx,
1499        );
1500    }
1501
1502    /// Cancels the last pending completion, if there are any pending.
1503    ///
1504    /// Returns whether a completion was canceled.
1505    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1506        let canceled = if self.pending_completions.pop().is_some() {
1507            true
1508        } else {
1509            let mut canceled = false;
1510            for pending_tool_use in self.tool_use.cancel_pending() {
1511                canceled = true;
1512                cx.emit(ThreadEvent::ToolFinished {
1513                    tool_use_id: pending_tool_use.id.clone(),
1514                    pending_tool_use: Some(pending_tool_use),
1515                    canceled: true,
1516                });
1517            }
1518            canceled
1519        };
1520        self.finalize_pending_checkpoint(cx);
1521        canceled
1522    }
1523
1524    pub fn feedback(&self) -> Option<ThreadFeedback> {
1525        self.feedback
1526    }
1527
1528    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1529        self.message_feedback.get(&message_id).copied()
1530    }
1531
1532    pub fn report_message_feedback(
1533        &mut self,
1534        message_id: MessageId,
1535        feedback: ThreadFeedback,
1536        cx: &mut Context<Self>,
1537    ) -> Task<Result<()>> {
1538        if self.message_feedback.get(&message_id) == Some(&feedback) {
1539            return Task::ready(Ok(()));
1540        }
1541
1542        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1543        let serialized_thread = self.serialize(cx);
1544        let thread_id = self.id().clone();
1545        let client = self.project.read(cx).client();
1546
1547        self.message_feedback.insert(message_id, feedback);
1548
1549        cx.notify();
1550
1551        let message_content = self
1552            .message(message_id)
1553            .map(|msg| msg.to_string())
1554            .unwrap_or_default();
1555
1556        cx.background_spawn(async move {
1557            let final_project_snapshot = final_project_snapshot.await;
1558            let serialized_thread = serialized_thread.await?;
1559            let thread_data =
1560                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1561
1562            let rating = match feedback {
1563                ThreadFeedback::Positive => "positive",
1564                ThreadFeedback::Negative => "negative",
1565            };
1566            telemetry::event!(
1567                "Assistant Thread Rated",
1568                rating,
1569                thread_id,
1570                message_id = message_id.0,
1571                message_content,
1572                thread_data,
1573                final_project_snapshot
1574            );
1575            client.telemetry().flush_events();
1576
1577            Ok(())
1578        })
1579    }
1580
1581    pub fn report_feedback(
1582        &mut self,
1583        feedback: ThreadFeedback,
1584        cx: &mut Context<Self>,
1585    ) -> Task<Result<()>> {
1586        let last_assistant_message_id = self
1587            .messages
1588            .iter()
1589            .rev()
1590            .find(|msg| msg.role == Role::Assistant)
1591            .map(|msg| msg.id);
1592
1593        if let Some(message_id) = last_assistant_message_id {
1594            self.report_message_feedback(message_id, feedback, cx)
1595        } else {
1596            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1597            let serialized_thread = self.serialize(cx);
1598            let thread_id = self.id().clone();
1599            let client = self.project.read(cx).client();
1600            self.feedback = Some(feedback);
1601            cx.notify();
1602
1603            cx.background_spawn(async move {
1604                let final_project_snapshot = final_project_snapshot.await;
1605                let serialized_thread = serialized_thread.await?;
1606                let thread_data = serde_json::to_value(serialized_thread)
1607                    .unwrap_or_else(|_| serde_json::Value::Null);
1608
1609                let rating = match feedback {
1610                    ThreadFeedback::Positive => "positive",
1611                    ThreadFeedback::Negative => "negative",
1612                };
1613                telemetry::event!(
1614                    "Assistant Thread Rated",
1615                    rating,
1616                    thread_id,
1617                    thread_data,
1618                    final_project_snapshot
1619                );
1620                client.telemetry().flush_events();
1621
1622                Ok(())
1623            })
1624        }
1625    }
1626
1627    /// Create a snapshot of the current project state including git information and unsaved buffers.
1628    fn project_snapshot(
1629        project: Entity<Project>,
1630        cx: &mut Context<Self>,
1631    ) -> Task<Arc<ProjectSnapshot>> {
1632        let git_store = project.read(cx).git_store().clone();
1633        let worktree_snapshots: Vec<_> = project
1634            .read(cx)
1635            .visible_worktrees(cx)
1636            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1637            .collect();
1638
1639        cx.spawn(async move |_, cx| {
1640            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1641
1642            let mut unsaved_buffers = Vec::new();
1643            cx.update(|app_cx| {
1644                let buffer_store = project.read(app_cx).buffer_store();
1645                for buffer_handle in buffer_store.read(app_cx).buffers() {
1646                    let buffer = buffer_handle.read(app_cx);
1647                    if buffer.is_dirty() {
1648                        if let Some(file) = buffer.file() {
1649                            let path = file.path().to_string_lossy().to_string();
1650                            unsaved_buffers.push(path);
1651                        }
1652                    }
1653                }
1654            })
1655            .ok();
1656
1657            Arc::new(ProjectSnapshot {
1658                worktree_snapshots,
1659                unsaved_buffer_paths: unsaved_buffers,
1660                timestamp: Utc::now(),
1661            })
1662        })
1663    }
1664
1665    fn worktree_snapshot(
1666        worktree: Entity<project::Worktree>,
1667        git_store: Entity<GitStore>,
1668        cx: &App,
1669    ) -> Task<WorktreeSnapshot> {
1670        cx.spawn(async move |cx| {
1671            // Get worktree path and snapshot
1672            let worktree_info = cx.update(|app_cx| {
1673                let worktree = worktree.read(app_cx);
1674                let path = worktree.abs_path().to_string_lossy().to_string();
1675                let snapshot = worktree.snapshot();
1676                (path, snapshot)
1677            });
1678
1679            let Ok((worktree_path, _snapshot)) = worktree_info else {
1680                return WorktreeSnapshot {
1681                    worktree_path: String::new(),
1682                    git_state: None,
1683                };
1684            };
1685
1686            let git_state = git_store
1687                .update(cx, |git_store, cx| {
1688                    git_store
1689                        .repositories()
1690                        .values()
1691                        .find(|repo| {
1692                            repo.read(cx)
1693                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1694                                .is_some()
1695                        })
1696                        .cloned()
1697                })
1698                .ok()
1699                .flatten()
1700                .map(|repo| {
1701                    repo.update(cx, |repo, _| {
1702                        let current_branch =
1703                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1704                        repo.send_job(None, |state, _| async move {
1705                            let RepositoryState::Local { backend, .. } = state else {
1706                                return GitState {
1707                                    remote_url: None,
1708                                    head_sha: None,
1709                                    current_branch,
1710                                    diff: None,
1711                                };
1712                            };
1713
1714                            let remote_url = backend.remote_url("origin");
1715                            let head_sha = backend.head_sha();
1716                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1717
1718                            GitState {
1719                                remote_url,
1720                                head_sha,
1721                                current_branch,
1722                                diff,
1723                            }
1724                        })
1725                    })
1726                });
1727
1728            let git_state = match git_state {
1729                Some(git_state) => match git_state.ok() {
1730                    Some(git_state) => git_state.await.ok(),
1731                    None => None,
1732                },
1733                None => None,
1734            };
1735
1736            WorktreeSnapshot {
1737                worktree_path,
1738                git_state,
1739            }
1740        })
1741    }
1742
1743    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1744        let mut markdown = Vec::new();
1745
1746        if let Some(summary) = self.summary() {
1747            writeln!(markdown, "# {summary}\n")?;
1748        };
1749
1750        for message in self.messages() {
1751            writeln!(
1752                markdown,
1753                "## {role}\n",
1754                role = match message.role {
1755                    Role::User => "User",
1756                    Role::Assistant => "Assistant",
1757                    Role::System => "System",
1758                }
1759            )?;
1760
1761            if !message.context.is_empty() {
1762                writeln!(markdown, "{}", message.context)?;
1763            }
1764
1765            for segment in &message.segments {
1766                match segment {
1767                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1768                    MessageSegment::Thinking(text) => {
1769                        writeln!(markdown, "<think>{}</think>\n", text)?
1770                    }
1771                }
1772            }
1773
1774            for tool_use in self.tool_uses_for_message(message.id, cx) {
1775                writeln!(
1776                    markdown,
1777                    "**Use Tool: {} ({})**",
1778                    tool_use.name, tool_use.id
1779                )?;
1780                writeln!(markdown, "```json")?;
1781                writeln!(
1782                    markdown,
1783                    "{}",
1784                    serde_json::to_string_pretty(&tool_use.input)?
1785                )?;
1786                writeln!(markdown, "```")?;
1787            }
1788
1789            for tool_result in self.tool_results_for_message(message.id) {
1790                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1791                if tool_result.is_error {
1792                    write!(markdown, " (Error)")?;
1793                }
1794
1795                writeln!(markdown, "**\n")?;
1796                writeln!(markdown, "{}", tool_result.content)?;
1797            }
1798        }
1799
1800        Ok(String::from_utf8_lossy(&markdown).to_string())
1801    }
1802
1803    pub fn keep_edits_in_range(
1804        &mut self,
1805        buffer: Entity<language::Buffer>,
1806        buffer_range: Range<language::Anchor>,
1807        cx: &mut Context<Self>,
1808    ) {
1809        self.action_log.update(cx, |action_log, cx| {
1810            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1811        });
1812    }
1813
1814    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1815        self.action_log
1816            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1817    }
1818
1819    pub fn reject_edits_in_range(
1820        &mut self,
1821        buffer: Entity<language::Buffer>,
1822        buffer_range: Range<language::Anchor>,
1823        cx: &mut Context<Self>,
1824    ) -> Task<Result<()>> {
1825        self.action_log.update(cx, |action_log, cx| {
1826            action_log.reject_edits_in_range(buffer, buffer_range, cx)
1827        })
1828    }
1829
1830    pub fn action_log(&self) -> &Entity<ActionLog> {
1831        &self.action_log
1832    }
1833
1834    pub fn project(&self) -> &Entity<Project> {
1835        &self.project
1836    }
1837
1838    pub fn cumulative_token_usage(&self) -> TokenUsage {
1839        self.cumulative_token_usage.clone()
1840    }
1841
1842    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1843        let model_registry = LanguageModelRegistry::read_global(cx);
1844        let Some(model) = model_registry.default_model() else {
1845            return TotalTokenUsage::default();
1846        };
1847
1848        let max = model.model.max_token_count();
1849
1850        #[cfg(debug_assertions)]
1851        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1852            .unwrap_or("0.8".to_string())
1853            .parse()
1854            .unwrap();
1855        #[cfg(not(debug_assertions))]
1856        let warning_threshold: f32 = 0.8;
1857
1858        let total = self.cumulative_token_usage.total_tokens() as usize;
1859
1860        let ratio = if total >= max {
1861            TokenUsageRatio::Exceeded
1862        } else if total as f32 / max as f32 >= warning_threshold {
1863            TokenUsageRatio::Warning
1864        } else {
1865            TokenUsageRatio::Normal
1866        };
1867
1868        TotalTokenUsage { total, max, ratio }
1869    }
1870
1871    pub fn deny_tool_use(
1872        &mut self,
1873        tool_use_id: LanguageModelToolUseId,
1874        tool_name: Arc<str>,
1875        cx: &mut Context<Self>,
1876    ) {
1877        let err = Err(anyhow::anyhow!(
1878            "Permission to run tool action denied by user"
1879        ));
1880
1881        self.tool_use
1882            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1883
1884        cx.emit(ThreadEvent::ToolFinished {
1885            tool_use_id,
1886            pending_tool_use: None,
1887            canceled: true,
1888        });
1889    }
1890}
1891
1892pub fn system_prompt_reminder(prompt_builder: &prompt_store::PromptBuilder) -> String {
1893    prompt_builder
1894        .generate_assistant_system_prompt_reminder()
1895        .unwrap_or_default()
1896}
1897
/// An error reported by a thread; carried by `ThreadEvent::ShowError`.
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// NOTE(review): presumably produced when the provider returns `PaymentRequiredError`
    /// (imported by this file); confirm against the completion-handling code.
    PaymentRequired,
    /// NOTE(review): presumably produced for `MaxMonthlySpendReachedError`; confirm.
    MaxMonthlySpendReached,
    /// Any other error, described by a `header` and a `message`.
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1907
/// Events emitted by a [`Thread`] through its `EventEmitter<ThreadEvent>` implementation.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error for subscribers to display.
    ShowError(ThreadError),
    StreamedCompletion,
    /// Assistant text associated with the given message.
    /// NOTE(review): presumably an incremental chunk; confirm with the streaming code.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text associated with the given message (same caveat as above).
    StreamedAssistantThinking(MessageId, String),
    DoneStreaming,
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools,
    /// A tool use has finished. `Thread::deny_tool_use` emits this with
    /// `pending_tool_use: None` and `canceled: true`.
    ToolFinished {
        // Silences the unused-field lint. The field is populated when the event is emitted,
        // but NOTE(review): it may have no readers; confirm before relying on or removing it.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
        /// Whether the tool was canceled by the user.
        canceled: bool,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}
1932
// Lets `Thread` entities emit `ThreadEvent`s via `cx.emit(...)`, as `deny_tool_use` does.
impl EventEmitter<ThreadEvent> for Thread {}
1934
/// A completion started by the thread.
/// NOTE(review): the code that creates and consumes these is outside this view; the
/// description here is inferred from the name and fields.
struct PendingCompletion {
    /// Identifier distinguishing this completion from others.
    id: usize,
    /// Held only to keep the task alive. NOTE(review): gpui tasks are typically cancelled
    /// when dropped; confirm that dropping this struct cancels the completion.
    _task: Task<()>,
}
1939
1940#[cfg(test)]
1941mod tests {
1942    use super::*;
1943    use crate::{ThreadStore, context_store::ContextStore, thread_store};
1944    use assistant_settings::AssistantSettings;
1945    use context_server::ContextServerSettings;
1946    use editor::EditorSettings;
1947    use gpui::TestAppContext;
1948    use project::{FakeFs, Project};
1949    use prompt_store::PromptBuilder;
1950    use serde_json::json;
1951    use settings::{Settings, SettingsStore};
1952    use std::sync::Arc;
1953    use theme::ThemeSettings;
1954    use util::path;
1955    use workspace::Workspace;
1956
    // A user message with one attached file stores that file in `message.context`
    // (wrapped in the `<context>`/`<files>` markup below). The completion request then
    // sends the context, the message text and the system prompt reminder, in that order.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, prompt_builder) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // The file path is rendered with the platform's separator, so the expectation differs per OS.
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // `{{` / `}}` are `format!` escapes for literal braces in the file contents.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 1);
        let actual_message = request.messages[0].string_contents();
        let expected_content = format!(
            "{}Please explain this code{}",
            expected_context,
            system_prompt_reminder(&prompt_builder)
        );

        assert_eq!(actual_message, expected_content);
    }
2030
    // Contexts already attached to an earlier message are not repeated.
    // Each message's `context`, and the corresponding request message, should contain only
    // the files that are new at that point in the thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store, _prompt_builder) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 3);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[0].string_contents().contains("file1.rs"));
        assert!(!request.messages[0].string_contents().contains("file2.rs"));
        assert!(!request.messages[0].string_contents().contains("file3.rs"));

        assert!(!request.messages[1].string_contents().contains("file1.rs"));
        assert!(request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(!request.messages[2].string_contents().contains("file2.rs"));
        assert!(request.messages[2].string_contents().contains("file3.rs"));
    }
2134
2135    #[gpui::test]
2136    async fn test_message_without_files(cx: &mut TestAppContext) {
2137        init_test_settings(cx);
2138
2139        let project = create_test_project(
2140            cx,
2141            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2142        )
2143        .await;
2144
2145        let (_, _thread_store, thread, _context_store, prompt_builder) =
2146            setup_test_environment(cx, project.clone()).await;
2147
2148        // Insert user message without any context (empty context vector)
2149        let message_id = thread.update(cx, |thread, cx| {
2150            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2151        });
2152
2153        // Check content and context in message object
2154        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2155
2156        // Context should be empty when no files are included
2157        assert_eq!(message.role, Role::User);
2158        assert_eq!(message.segments.len(), 1);
2159        assert_eq!(
2160            message.segments[0],
2161            MessageSegment::Text("What is the best way to learn Rust?".to_string())
2162        );
2163        assert_eq!(message.context, "");
2164
2165        // Check message in request
2166        let request = thread.read_with(cx, |thread, cx| {
2167            thread.to_completion_request(RequestKind::Chat, cx)
2168        });
2169
2170        assert_eq!(request.messages.len(), 1);
2171        let actual_message = request.messages[0].string_contents();
2172        let expected_content = format!(
2173            "What is the best way to learn Rust?{}",
2174            system_prompt_reminder(&prompt_builder)
2175        );
2176
2177        assert_eq!(actual_message, expected_content);
2178
2179        // Add second message, also without context
2180        let message2_id = thread.update(cx, |thread, cx| {
2181            thread.insert_user_message("Are there any good books?", vec![], None, cx)
2182        });
2183
2184        let message2 =
2185            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2186        assert_eq!(message2.context, "");
2187
2188        // Check that both messages appear in the request
2189        let request = thread.read_with(cx, |thread, cx| {
2190            thread.to_completion_request(RequestKind::Chat, cx)
2191        });
2192
2193        assert_eq!(request.messages.len(), 2);
2194        // First message should be the system prompt
2195        assert_eq!(request.messages[0].role, Role::User);
2196
2197        // Second message should be the user message with prompt reminder
2198        let actual_message = request.messages[1].string_contents();
2199        let expected_content = format!(
2200            "Are there any good books?{}",
2201            system_prompt_reminder(&prompt_builder)
2202        );
2203
2204        assert_eq!(actual_message, expected_content);
2205    }
2206
    // After a buffer that was attached as context is modified, the next completion request
    // ends with a user message listing the changed file, followed by the system prompt
    // reminder.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store, prompt_builder) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert at byte offset 1 (inside `fn`, not at the end of line 1). The position
            // doesn't matter here; the test only needs the buffer to change after it was added
            // as context.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        let actual_message = last_message.string_contents();
        let expected_content = format!(
            "These files changed since last read:\n- code.rs\n{}",
            system_prompt_reminder(&prompt_builder)
        );

        assert_eq!(
            actual_message, expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2288
    // Installs a test `SettingsStore` as a global, then initializes and registers every
    // settings type and subsystem these tests touch. The calls run in the order listed here.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }
2303
2304    // Helper to create a test project with test files
2305    async fn create_test_project(
2306        cx: &mut TestAppContext,
2307        files: serde_json::Value,
2308    ) -> Entity<Project> {
2309        let fs = FakeFs::new(cx.executor());
2310        fs.insert_tree(path!("/test"), files).await;
2311        Project::test(fs, [path!("/test").as_ref()], cx).await
2312    }
2313
2314    async fn setup_test_environment(
2315        cx: &mut TestAppContext,
2316        project: Entity<Project>,
2317    ) -> (
2318        Entity<Workspace>,
2319        Entity<ThreadStore>,
2320        Entity<Thread>,
2321        Entity<ContextStore>,
2322        Arc<PromptBuilder>,
2323    ) {
2324        let (workspace, cx) =
2325            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2326
2327        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
2328
2329        let thread_store = cx.update(|_, cx| {
2330            ThreadStore::new(project.clone(), Arc::default(), prompt_builder.clone(), cx).unwrap()
2331        });
2332
2333        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2334        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2335
2336        (
2337            workspace,
2338            thread_store,
2339            thread,
2340            context_store,
2341            prompt_builder,
2342        )
2343    }
2344
2345    async fn add_file_to_context(
2346        project: &Entity<Project>,
2347        context_store: &Entity<ContextStore>,
2348        path: &str,
2349        cx: &mut TestAppContext,
2350    ) -> Result<Entity<language::Buffer>> {
2351        let buffer_path = project
2352            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2353            .unwrap();
2354
2355        let buffer = project
2356            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2357            .await
2358            .unwrap();
2359
2360        context_store
2361            .update(cx, |store, cx| {
2362                store.add_file_from_buffer(buffer.clone(), cx)
2363            })
2364            .await?;
2365
2366        Ok(buffer)
2367    }
2368}