thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5
   6use anyhow::{Context as _, Result, anyhow};
   7use assistant_settings::AssistantSettings;
   8use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
   9use chrono::{DateTime, Utc};
  10use collections::{BTreeMap, HashMap};
  11use feature_flags::{self, FeatureFlagAppExt};
  12use futures::future::Shared;
  13use futures::{FutureExt, StreamExt as _};
  14use git::repository::DiffType;
  15use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  16use language_model::{
  17    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
  18    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  19    LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  20    PaymentRequiredError, Role, StopReason, TokenUsage,
  21};
  22use project::Project;
  23use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  24use prompt_store::PromptBuilder;
  25use schemars::JsonSchema;
  26use serde::{Deserialize, Serialize};
  27use settings::Settings;
  28use thiserror::Error;
  29use util::{ResultExt as _, TryFutureExt as _, post_inc};
  30use uuid::Uuid;
  31
  32use crate::context::{AssistantContext, ContextId, format_context_as_string};
  33use crate::thread_store::{
  34    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  35    SerializedToolUse, SharedProjectContext,
  36};
  37use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  38
/// The purpose of a completion request built from a thread.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// Regular conversation turn; tool results are attached to the request messages.
    Chat,
    /// Used when summarizing a thread.
    /// Messages that carry tool results are left out of the request.
    Summarize,
}
  45
/// Unique identifier of a [`Thread`].
///
/// Fresh ids are random UUID strings (`ThreadId::new`).
/// Existing ids can be rebuilt from strings via `From<&str>`.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
  50
  51impl ThreadId {
  52    pub fn new() -> Self {
  53        Self(Uuid::new_v4().to_string().into())
  54    }
  55}
  56
  57impl std::fmt::Display for ThreadId {
  58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  59        write!(f, "{}", self.0)
  60    }
  61}
  62
  63impl From<&str> for ThreadId {
  64    fn from(value: &str) -> Self {
  65        Self(value.into())
  66    }
  67}
  68
/// Identifier of a [`Message`] within a thread.
///
/// Ids are handed out sequentially via `MessageId::post_inc`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
  71
  72impl MessageId {
  73    fn post_inc(&mut self) -> Self {
  74        Self(post_inc(&mut self.0))
  75    }
  76}
  77
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    pub id: MessageId,
    pub role: Role,
    /// Ordered text and thinking segments that make up the message body.
    pub segments: Vec<MessageSegment>,
    /// Formatted context text attached to the message (empty when there is none).
    /// `Message::to_string` puts it in front of the segments.
    pub context: String,
}
  86
  87impl Message {
  88    /// Returns whether the message contains any meaningful text that should be displayed
  89    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  90    pub fn should_display_content(&self) -> bool {
  91        self.segments.iter().all(|segment| segment.should_display())
  92    }
  93
  94    pub fn push_thinking(&mut self, text: &str) {
  95        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  96            segment.push_str(text);
  97        } else {
  98            self.segments
  99                .push(MessageSegment::Thinking(text.to_string()));
 100        }
 101    }
 102
 103    pub fn push_text(&mut self, text: &str) {
 104        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 105            segment.push_str(text);
 106        } else {
 107            self.segments.push(MessageSegment::Text(text.to_string()));
 108        }
 109    }
 110
 111    pub fn to_string(&self) -> String {
 112        let mut result = String::new();
 113
 114        if !self.context.is_empty() {
 115            result.push_str(&self.context);
 116        }
 117
 118        for segment in &self.segments {
 119            match segment {
 120                MessageSegment::Text(text) => result.push_str(text),
 121                MessageSegment::Thinking(text) => {
 122                    result.push_str("<think>");
 123                    result.push_str(text);
 124                    result.push_str("</think>");
 125                }
 126            }
 127        }
 128
 129        result
 130    }
 131}
 132
/// One contiguous piece of a [`Message`]: regular text or model "thinking" text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Regular message text.
    Text(String),
    /// Model reasoning text; wrapped in `<think>` tags when a message is flattened to a string.
    Thinking(String),
}
 138
 139impl MessageSegment {
 140    pub fn text_mut(&mut self) -> &mut String {
 141        match self {
 142            Self::Text(text) => text,
 143            Self::Thinking(text) => text,
 144        }
 145    }
 146
 147    pub fn should_display(&self) -> bool {
 148        // We add USING_TOOL_MARKER when making a request that includes tool uses
 149        // without non-whitespace text around them, and this can cause the model
 150        // to mimic the pattern, so we consider those segments not displayable.
 151        match self {
 152            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 153            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 154        }
 155    }
 156}
 157
/// Snapshot of the project's state.
/// `Thread::new` captures one as the thread's initial project snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // NOTE(review): presumably paths of buffers with unsaved edits — confirm.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
 164
/// Snapshot of a single worktree within a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    // NOTE(review): presumably `None` when the worktree is not inside a git repository — confirm.
    pub git_state: Option<GitState>,
}
 170
/// Git information recorded for a worktree; every field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    // NOTE(review): presumably the textual diff of the working tree — confirm.
    pub diff: Option<String>,
}
 178
/// A git checkpoint paired with the thread message it was taken for.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// Message the checkpoint belongs to; restoring it truncates the thread at this message.
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
 184
/// Feedback rating a user can attach to a thread or to an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 190
/// Status of the most recent checkpoint restore started by `Thread::restore_checkpoint`.
pub enum LastRestoreCheckpoint {
    /// The restore for `message_id` is still running.
    Pending {
        message_id: MessageId,
    },
    /// The restore for `message_id` failed; `error` holds the error's display text.
    Error {
        message_id: MessageId,
        error: String,
    },
}
 200
 201impl LastRestoreCheckpoint {
 202    pub fn message_id(&self) -> MessageId {
 203        match self {
 204            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 205            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 206        }
 207    }
 208}
 209
/// Progress of the thread's detailed summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary has been requested (the default).
    #[default]
    NotGenerated,
    // NOTE(review): presumably `message_id` is the last message being summarized — confirm.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary is available; `Thread::latest_detailed_summary` returns `text`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 222
/// Aggregate token usage for a thread, relative to a maximum.
// NOTE(review): presumably `max` is the model's context window size — confirm.
#[derive(Default)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
    pub ratio: TokenUsageRatio,
}
 229
/// Coarse classification of how close token usage is to its limit.
#[derive(Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    /// Usage is comfortably below the limit (the default).
    #[default]
    Normal,
    /// Usage is approaching the limit.
    Warning,
    /// Usage has reached or passed the limit.
    Exceeded,
}
 237
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last modification time; bumped by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// Generated summary. `None` until one exists; `set_summary` ignores edits before that.
    summary: Option<SharedString>,
    // NOTE(review): presumably the in-flight summary-generation task — confirm.
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    /// Messages in insertion order; ids are allocated by incrementing `next_message_id`.
    messages: Vec<Message>,
    /// Id handed to the next inserted message.
    next_message_id: MessageId,
    /// Every context entry attached to the thread, keyed by id.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Ids of the context entries first introduced by each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Shared project context used to build the system prompt.
    project_context: SharedProjectContext,
    /// Restorable git checkpoints keyed by the message they were taken for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// In-flight completions; a non-empty list makes `is_generating` return `true`.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Arc<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Status of the latest checkpoint restore; `None` when idle or after a successful restore.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken with the latest user message.
    /// `finalize_pending_checkpoint` keeps or discards it once generation stops.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    cumulative_token_usage: TokenUsage,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
}
 265
 266impl Thread {
 267    pub fn new(
 268        project: Entity<Project>,
 269        tools: Arc<ToolWorkingSet>,
 270        prompt_builder: Arc<PromptBuilder>,
 271        system_prompt: SharedProjectContext,
 272        cx: &mut Context<Self>,
 273    ) -> Self {
 274        Self {
 275            id: ThreadId::new(),
 276            updated_at: Utc::now(),
 277            summary: None,
 278            pending_summary: Task::ready(None),
 279            detailed_summary_state: DetailedSummaryState::NotGenerated,
 280            messages: Vec::new(),
 281            next_message_id: MessageId(0),
 282            context: BTreeMap::default(),
 283            context_by_message: HashMap::default(),
 284            project_context: system_prompt,
 285            checkpoints_by_message: HashMap::default(),
 286            completion_count: 0,
 287            pending_completions: Vec::new(),
 288            project: project.clone(),
 289            prompt_builder,
 290            tools: tools.clone(),
 291            last_restore_checkpoint: None,
 292            pending_checkpoint: None,
 293            tool_use: ToolUseState::new(tools.clone()),
 294            action_log: cx.new(|_| ActionLog::new(project.clone())),
 295            initial_project_snapshot: {
 296                let project_snapshot = Self::project_snapshot(project, cx);
 297                cx.foreground_executor()
 298                    .spawn(async move { Some(project_snapshot.await) })
 299                    .shared()
 300            },
 301            cumulative_token_usage: TokenUsage::default(),
 302            feedback: None,
 303            message_feedback: HashMap::default(),
 304        }
 305    }
 306
 307    pub fn deserialize(
 308        id: ThreadId,
 309        serialized: SerializedThread,
 310        project: Entity<Project>,
 311        tools: Arc<ToolWorkingSet>,
 312        prompt_builder: Arc<PromptBuilder>,
 313        project_context: SharedProjectContext,
 314        cx: &mut Context<Self>,
 315    ) -> Self {
 316        let next_message_id = MessageId(
 317            serialized
 318                .messages
 319                .last()
 320                .map(|message| message.id.0 + 1)
 321                .unwrap_or(0),
 322        );
 323        let tool_use =
 324            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 325
 326        Self {
 327            id,
 328            updated_at: serialized.updated_at,
 329            summary: Some(serialized.summary),
 330            pending_summary: Task::ready(None),
 331            detailed_summary_state: serialized.detailed_summary_state,
 332            messages: serialized
 333                .messages
 334                .into_iter()
 335                .map(|message| Message {
 336                    id: message.id,
 337                    role: message.role,
 338                    segments: message
 339                        .segments
 340                        .into_iter()
 341                        .map(|segment| match segment {
 342                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 343                            SerializedMessageSegment::Thinking { text } => {
 344                                MessageSegment::Thinking(text)
 345                            }
 346                        })
 347                        .collect(),
 348                    context: message.context,
 349                })
 350                .collect(),
 351            next_message_id,
 352            context: BTreeMap::default(),
 353            context_by_message: HashMap::default(),
 354            project_context,
 355            checkpoints_by_message: HashMap::default(),
 356            completion_count: 0,
 357            pending_completions: Vec::new(),
 358            last_restore_checkpoint: None,
 359            pending_checkpoint: None,
 360            project: project.clone(),
 361            prompt_builder,
 362            tools,
 363            tool_use,
 364            action_log: cx.new(|_| ActionLog::new(project)),
 365            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 366            cumulative_token_usage: serialized.cumulative_token_usage,
 367            feedback: None,
 368            message_feedback: HashMap::default(),
 369        }
 370    }
 371
 372    pub fn id(&self) -> &ThreadId {
 373        &self.id
 374    }
 375
 376    pub fn is_empty(&self) -> bool {
 377        self.messages.is_empty()
 378    }
 379
 380    pub fn updated_at(&self) -> DateTime<Utc> {
 381        self.updated_at
 382    }
 383
 384    pub fn touch_updated_at(&mut self) {
 385        self.updated_at = Utc::now();
 386    }
 387
 388    pub fn summary(&self) -> Option<SharedString> {
 389        self.summary.clone()
 390    }
 391
 392    pub fn project_context(&self) -> SharedProjectContext {
 393        self.project_context.clone()
 394    }
 395
    /// Summary used while no summary is available.
    /// `set_summary` also substitutes it for an empty summary.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 397
 398    pub fn summary_or_default(&self) -> SharedString {
 399        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 400    }
 401
 402    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 403        let Some(current_summary) = &self.summary else {
 404            // Don't allow setting summary until generated
 405            return;
 406        };
 407
 408        let mut new_summary = new_summary.into();
 409
 410        if new_summary.is_empty() {
 411            new_summary = Self::DEFAULT_SUMMARY;
 412        }
 413
 414        if current_summary != &new_summary {
 415            self.summary = Some(new_summary);
 416            cx.emit(ThreadEvent::SummaryChanged);
 417        }
 418    }
 419
 420    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 421        self.latest_detailed_summary()
 422            .unwrap_or_else(|| self.text().into())
 423    }
 424
 425    fn latest_detailed_summary(&self) -> Option<SharedString> {
 426        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 427            Some(text.clone())
 428        } else {
 429            None
 430        }
 431    }
 432
 433    pub fn message(&self, id: MessageId) -> Option<&Message> {
 434        self.messages.iter().find(|message| message.id == id)
 435    }
 436
 437    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 438        self.messages.iter()
 439    }
 440
 441    pub fn is_generating(&self) -> bool {
 442        !self.pending_completions.is_empty() || !self.all_tools_finished()
 443    }
 444
 445    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 446        &self.tools
 447    }
 448
 449    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 450        self.tool_use
 451            .pending_tool_uses()
 452            .into_iter()
 453            .find(|tool_use| &tool_use.id == id)
 454    }
 455
 456    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 457        self.tool_use
 458            .pending_tool_uses()
 459            .into_iter()
 460            .filter(|tool_use| tool_use.status.needs_confirmation())
 461    }
 462
 463    pub fn has_pending_tool_uses(&self) -> bool {
 464        !self.tool_use.pending_tool_uses().is_empty()
 465    }
 466
 467    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 468        self.checkpoints_by_message.get(&id).cloned()
 469    }
 470
    /// Restores the project's git state to `checkpoint`.
    /// On success, also truncates the thread at the checkpoint's message.
    ///
    /// State transitions:
    /// - `last_restore_checkpoint` becomes `Pending` immediately.
    /// - It becomes `Error` if the git restore fails, and is cleared on success.
    /// - Either way, `pending_checkpoint` is cleared afterwards and
    ///   [`ThreadEvent::CheckpointChanged`] is emitted.
    ///
    /// The returned task resolves to the git restore result. It fails instead if the
    /// thread entity can no longer be updated.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Publish the in-progress restore before starting any git work.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Leave the messages untouched and record the failure.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Git state now matches this message's checkpoint: drop the message
                    // and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 506
    /// Decides whether the pending checkpoint is worth keeping once generation has stopped.
    ///
    /// Does nothing while the thread is still generating or when no checkpoint is pending.
    /// Otherwise it takes a fresh git checkpoint and compares it with the pending one:
    /// - Equal (or the comparison succeeded and nothing changed): the pending checkpoint is
    ///   deleted.
    /// - Different, or the comparison errored: the pending checkpoint is recorded via
    ///   `insert_checkpoint`.
    ///
    /// The fresh checkpoint is deleted afterwards, unless an earlier entity update fails.
    /// If taking the fresh checkpoint fails, the pending checkpoint is kept.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison counts as "changed", so the checkpoint is kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    // Nothing changed since the user message; the checkpoint is useless.
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The fresh checkpoint was only needed for the comparison.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 557
 558    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 559        self.checkpoints_by_message
 560            .insert(checkpoint.message_id, checkpoint);
 561        cx.emit(ThreadEvent::CheckpointChanged);
 562        cx.notify();
 563    }
 564
 565    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 566        self.last_restore_checkpoint.as_ref()
 567    }
 568
 569    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 570        let Some(message_ix) = self
 571            .messages
 572            .iter()
 573            .rposition(|message| message.id == message_id)
 574        else {
 575            return;
 576        };
 577        for deleted_message in self.messages.drain(message_ix..) {
 578            self.context_by_message.remove(&deleted_message.id);
 579            self.checkpoints_by_message.remove(&deleted_message.id);
 580        }
 581        cx.notify();
 582    }
 583
 584    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 585        self.context_by_message
 586            .get(&id)
 587            .into_iter()
 588            .flat_map(|context| {
 589                context
 590                    .iter()
 591                    .filter_map(|context_id| self.context.get(&context_id))
 592            })
 593    }
 594
 595    /// Returns whether all of the tool uses have finished running.
 596    pub fn all_tools_finished(&self) -> bool {
 597        // If the only pending tool uses left are the ones with errors, then
 598        // that means that we've finished running all of the pending tools.
 599        self.tool_use
 600            .pending_tool_uses()
 601            .iter()
 602            .all(|tool_use| tool_use.status.is_error())
 603    }
 604
 605    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 606        self.tool_use.tool_uses_for_message(id, cx)
 607    }
 608
 609    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 610        self.tool_use.tool_results_for_message(id)
 611    }
 612
 613    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 614        self.tool_use.tool_result(id)
 615    }
 616
 617    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 618        self.tool_use.message_has_tool_results(message_id)
 619    }
 620
 621    pub fn insert_user_message(
 622        &mut self,
 623        text: impl Into<String>,
 624        context: Vec<AssistantContext>,
 625        git_checkpoint: Option<GitStoreCheckpoint>,
 626        cx: &mut Context<Self>,
 627    ) -> MessageId {
 628        let text = text.into();
 629
 630        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 631
 632        // Filter out contexts that have already been included in previous messages
 633        let new_context: Vec<_> = context
 634            .into_iter()
 635            .filter(|ctx| !self.context.contains_key(&ctx.id()))
 636            .collect();
 637
 638        if !new_context.is_empty() {
 639            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 640                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 641                    message.context = context_string;
 642                }
 643            }
 644
 645            self.action_log.update(cx, |log, cx| {
 646                // Track all buffers added as context
 647                for ctx in &new_context {
 648                    match ctx {
 649                        AssistantContext::File(file_ctx) => {
 650                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 651                        }
 652                        AssistantContext::Directory(dir_ctx) => {
 653                            for context_buffer in &dir_ctx.context_buffers {
 654                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 655                            }
 656                        }
 657                        AssistantContext::Symbol(symbol_ctx) => {
 658                            log.buffer_added_as_context(
 659                                symbol_ctx.context_symbol.buffer.clone(),
 660                                cx,
 661                            );
 662                        }
 663                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 664                    }
 665                }
 666            });
 667        }
 668
 669        let context_ids = new_context
 670            .iter()
 671            .map(|context| context.id())
 672            .collect::<Vec<_>>();
 673        self.context.extend(
 674            new_context
 675                .into_iter()
 676                .map(|context| (context.id(), context)),
 677        );
 678        self.context_by_message.insert(message_id, context_ids);
 679
 680        if let Some(git_checkpoint) = git_checkpoint {
 681            self.pending_checkpoint = Some(ThreadCheckpoint {
 682                message_id,
 683                git_checkpoint,
 684            });
 685        }
 686
 687        self.auto_capture_telemetry(cx);
 688
 689        message_id
 690    }
 691
 692    pub fn insert_message(
 693        &mut self,
 694        role: Role,
 695        segments: Vec<MessageSegment>,
 696        cx: &mut Context<Self>,
 697    ) -> MessageId {
 698        let id = self.next_message_id.post_inc();
 699        self.messages.push(Message {
 700            id,
 701            role,
 702            segments,
 703            context: String::new(),
 704        });
 705        self.touch_updated_at();
 706        cx.emit(ThreadEvent::MessageAdded(id));
 707        id
 708    }
 709
 710    pub fn edit_message(
 711        &mut self,
 712        id: MessageId,
 713        new_role: Role,
 714        new_segments: Vec<MessageSegment>,
 715        cx: &mut Context<Self>,
 716    ) -> bool {
 717        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 718            return false;
 719        };
 720        message.role = new_role;
 721        message.segments = new_segments;
 722        self.touch_updated_at();
 723        cx.emit(ThreadEvent::MessageEdited(id));
 724        true
 725    }
 726
 727    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 728        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 729            return false;
 730        };
 731        self.messages.remove(index);
 732        self.context_by_message.remove(&id);
 733        self.touch_updated_at();
 734        cx.emit(ThreadEvent::MessageDeleted(id));
 735        true
 736    }
 737
 738    /// Returns the representation of this [`Thread`] in a textual form.
 739    ///
 740    /// This is the representation we use when attaching a thread as context to another thread.
 741    pub fn text(&self) -> String {
 742        let mut text = String::new();
 743
 744        for message in &self.messages {
 745            text.push_str(match message.role {
 746                language_model::Role::User => "User:",
 747                language_model::Role::Assistant => "Assistant:",
 748                language_model::Role::System => "System:",
 749            });
 750            text.push('\n');
 751
 752            for segment in &message.segments {
 753                match segment {
 754                    MessageSegment::Text(content) => text.push_str(content),
 755                    MessageSegment::Thinking(content) => {
 756                        text.push_str(&format!("<think>{}</think>", content))
 757                    }
 758                }
 759            }
 760            text.push('\n');
 761        }
 762
 763        text
 764    }
 765
 766    /// Serializes this thread into a format for storage or telemetry.
 767    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 768        let initial_project_snapshot = self.initial_project_snapshot.clone();
 769        cx.spawn(async move |this, cx| {
 770            let initial_project_snapshot = initial_project_snapshot.await;
 771            this.read_with(cx, |this, cx| SerializedThread {
 772                version: SerializedThread::VERSION.to_string(),
 773                summary: this.summary_or_default(),
 774                updated_at: this.updated_at(),
 775                messages: this
 776                    .messages()
 777                    .map(|message| SerializedMessage {
 778                        id: message.id,
 779                        role: message.role,
 780                        segments: message
 781                            .segments
 782                            .iter()
 783                            .map(|segment| match segment {
 784                                MessageSegment::Text(text) => {
 785                                    SerializedMessageSegment::Text { text: text.clone() }
 786                                }
 787                                MessageSegment::Thinking(text) => {
 788                                    SerializedMessageSegment::Thinking { text: text.clone() }
 789                                }
 790                            })
 791                            .collect(),
 792                        tool_uses: this
 793                            .tool_uses_for_message(message.id, cx)
 794                            .into_iter()
 795                            .map(|tool_use| SerializedToolUse {
 796                                id: tool_use.id,
 797                                name: tool_use.name,
 798                                input: tool_use.input,
 799                            })
 800                            .collect(),
 801                        tool_results: this
 802                            .tool_results_for_message(message.id)
 803                            .into_iter()
 804                            .map(|tool_result| SerializedToolResult {
 805                                tool_use_id: tool_result.tool_use_id.clone(),
 806                                is_error: tool_result.is_error,
 807                                content: tool_result.content.clone(),
 808                            })
 809                            .collect(),
 810                        context: message.context.clone(),
 811                    })
 812                    .collect(),
 813                initial_project_snapshot,
 814                cumulative_token_usage: this.cumulative_token_usage.clone(),
 815                detailed_summary_state: this.detailed_summary_state.clone(),
 816            })
 817        })
 818    }
 819
 820    pub fn send_to_model(
 821        &mut self,
 822        model: Arc<dyn LanguageModel>,
 823        request_kind: RequestKind,
 824        cx: &mut Context<Self>,
 825    ) {
 826        let mut request = self.to_completion_request(request_kind, cx);
 827        if model.supports_tools() {
 828            request.tools = {
 829                let mut tools = Vec::new();
 830                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 831                    LanguageModelRequestTool {
 832                        name: tool.name(),
 833                        description: tool.description(),
 834                        input_schema: tool.input_schema(model.tool_input_format()),
 835                    }
 836                }));
 837
 838                tools
 839            };
 840        }
 841
 842        self.stream_completion(request, model, cx);
 843    }
 844
 845    pub fn used_tools_since_last_user_message(&self) -> bool {
 846        for message in self.messages.iter().rev() {
 847            if self.tool_use.message_has_tool_results(message.id) {
 848                return true;
 849            } else if message.role == Role::User {
 850                return false;
 851            }
 852        }
 853
 854        false
 855    }
 856
 857    pub fn to_completion_request(
 858        &self,
 859        request_kind: RequestKind,
 860        cx: &App,
 861    ) -> LanguageModelRequest {
 862        let mut request = LanguageModelRequest {
 863            messages: vec![],
 864            tools: Vec::new(),
 865            stop: Vec::new(),
 866            temperature: None,
 867        };
 868
 869        if let Some(project_context) = self.project_context.borrow().as_ref() {
 870            if let Some(system_prompt) = self
 871                .prompt_builder
 872                .generate_assistant_system_prompt(project_context)
 873                .context("failed to generate assistant system prompt")
 874                .log_err()
 875            {
 876                request.messages.push(LanguageModelRequestMessage {
 877                    role: Role::System,
 878                    content: vec![MessageContent::Text(system_prompt)],
 879                    cache: true,
 880                });
 881            }
 882        } else {
 883            log::error!("project_context not set.")
 884        }
 885
 886        for message in &self.messages {
 887            let mut request_message = LanguageModelRequestMessage {
 888                role: message.role,
 889                content: Vec::new(),
 890                cache: false,
 891            };
 892
 893            match request_kind {
 894                RequestKind::Chat => {
 895                    self.tool_use
 896                        .attach_tool_results(message.id, &mut request_message);
 897                }
 898                RequestKind::Summarize => {
 899                    // We don't care about tool use during summarization.
 900                    if self.tool_use.message_has_tool_results(message.id) {
 901                        continue;
 902                    }
 903                }
 904            }
 905
 906            if !message.segments.is_empty() {
 907                request_message
 908                    .content
 909                    .push(MessageContent::Text(message.to_string()));
 910            }
 911
 912            match request_kind {
 913                RequestKind::Chat => {
 914                    self.tool_use
 915                        .attach_tool_uses(message.id, &mut request_message);
 916                }
 917                RequestKind::Summarize => {
 918                    // We don't care about tool use during summarization.
 919                }
 920            };
 921
 922            request.messages.push(request_message);
 923        }
 924
 925        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 926        if let Some(last) = request.messages.last_mut() {
 927            last.cache = true;
 928        }
 929
 930        self.attached_tracked_files_state(&mut request.messages, cx);
 931
 932        request
 933    }
 934
 935    fn attached_tracked_files_state(
 936        &self,
 937        messages: &mut Vec<LanguageModelRequestMessage>,
 938        cx: &App,
 939    ) {
 940        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 941
 942        let mut stale_message = String::new();
 943
 944        let action_log = self.action_log.read(cx);
 945
 946        for stale_file in action_log.stale_buffers(cx) {
 947            let Some(file) = stale_file.read(cx).file() else {
 948                continue;
 949            };
 950
 951            if stale_message.is_empty() {
 952                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
 953            }
 954
 955            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 956        }
 957
 958        let mut content = Vec::with_capacity(2);
 959
 960        if !stale_message.is_empty() {
 961            content.push(stale_message.into());
 962        }
 963
 964        if action_log.has_edited_files_since_project_diagnostics_check() {
 965            content.push(
 966                "\n\nWhen you're done making changes, make sure to check project diagnostics \
 967                and fix all errors AND warnings you introduced! \
 968                DO NOT mention you're going to do this until you're done."
 969                    .into(),
 970            );
 971        }
 972
 973        if !content.is_empty() {
 974            let context_message = LanguageModelRequestMessage {
 975                role: Role::User,
 976                content,
 977                cache: false,
 978            };
 979
 980            messages.push(context_message);
 981        }
 982    }
 983
 984    pub fn stream_completion(
 985        &mut self,
 986        request: LanguageModelRequest,
 987        model: Arc<dyn LanguageModel>,
 988        cx: &mut Context<Self>,
 989    ) {
 990        let pending_completion_id = post_inc(&mut self.completion_count);
 991
 992        let task = cx.spawn(async move |thread, cx| {
 993            let stream = model.stream_completion(request, &cx);
 994            let initial_token_usage =
 995                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
 996            let stream_completion = async {
 997                let mut events = stream.await?;
 998                let mut stop_reason = StopReason::EndTurn;
 999                let mut current_token_usage = TokenUsage::default();
1000
1001                while let Some(event) = events.next().await {
1002                    let event = event?;
1003
1004                    thread.update(cx, |thread, cx| {
1005                        match event {
1006                            LanguageModelCompletionEvent::StartMessage { .. } => {
1007                                thread.insert_message(
1008                                    Role::Assistant,
1009                                    vec![MessageSegment::Text(String::new())],
1010                                    cx,
1011                                );
1012                            }
1013                            LanguageModelCompletionEvent::Stop(reason) => {
1014                                stop_reason = reason;
1015                            }
1016                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1017                                thread.cumulative_token_usage =
1018                                    thread.cumulative_token_usage.clone() + token_usage.clone()
1019                                        - current_token_usage.clone();
1020                                current_token_usage = token_usage;
1021                            }
1022                            LanguageModelCompletionEvent::Text(chunk) => {
1023                                if let Some(last_message) = thread.messages.last_mut() {
1024                                    if last_message.role == Role::Assistant {
1025                                        last_message.push_text(&chunk);
1026                                        cx.emit(ThreadEvent::StreamedAssistantText(
1027                                            last_message.id,
1028                                            chunk,
1029                                        ));
1030                                    } else {
1031                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1032                                        // of a new Assistant response.
1033                                        //
1034                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1035                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1036                                        thread.insert_message(
1037                                            Role::Assistant,
1038                                            vec![MessageSegment::Text(chunk.to_string())],
1039                                            cx,
1040                                        );
1041                                    };
1042                                }
1043                            }
1044                            LanguageModelCompletionEvent::Thinking(chunk) => {
1045                                if let Some(last_message) = thread.messages.last_mut() {
1046                                    if last_message.role == Role::Assistant {
1047                                        last_message.push_thinking(&chunk);
1048                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1049                                            last_message.id,
1050                                            chunk,
1051                                        ));
1052                                    } else {
1053                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1054                                        // of a new Assistant response.
1055                                        //
1056                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1057                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1058                                        thread.insert_message(
1059                                            Role::Assistant,
1060                                            vec![MessageSegment::Thinking(chunk.to_string())],
1061                                            cx,
1062                                        );
1063                                    };
1064                                }
1065                            }
1066                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1067                                let last_assistant_message_id = thread
1068                                    .messages
1069                                    .iter_mut()
1070                                    .rfind(|message| message.role == Role::Assistant)
1071                                    .map(|message| message.id)
1072                                    .unwrap_or_else(|| {
1073                                        thread.insert_message(Role::Assistant, vec![], cx)
1074                                    });
1075
1076                                thread.tool_use.request_tool_use(
1077                                    last_assistant_message_id,
1078                                    tool_use,
1079                                    cx,
1080                                );
1081                            }
1082                        }
1083
1084                        thread.touch_updated_at();
1085                        cx.emit(ThreadEvent::StreamedCompletion);
1086                        cx.notify();
1087
1088                        thread.auto_capture_telemetry(cx);
1089                    })?;
1090
1091                    smol::future::yield_now().await;
1092                }
1093
1094                thread.update(cx, |thread, cx| {
1095                    thread
1096                        .pending_completions
1097                        .retain(|completion| completion.id != pending_completion_id);
1098
1099                    if thread.summary.is_none() && thread.messages.len() >= 2 {
1100                        thread.summarize(cx);
1101                    }
1102                })?;
1103
1104                anyhow::Ok(stop_reason)
1105            };
1106
1107            let result = stream_completion.await;
1108
1109            thread
1110                .update(cx, |thread, cx| {
1111                    thread.finalize_pending_checkpoint(cx);
1112                    match result.as_ref() {
1113                        Ok(stop_reason) => match stop_reason {
1114                            StopReason::ToolUse => {
1115                                let tool_uses = thread.use_pending_tools(cx);
1116                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1117                            }
1118                            StopReason::EndTurn => {}
1119                            StopReason::MaxTokens => {}
1120                        },
1121                        Err(error) => {
1122                            if error.is::<PaymentRequiredError>() {
1123                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1124                            } else if error.is::<MaxMonthlySpendReachedError>() {
1125                                cx.emit(ThreadEvent::ShowError(
1126                                    ThreadError::MaxMonthlySpendReached,
1127                                ));
1128                            } else {
1129                                let error_message = error
1130                                    .chain()
1131                                    .map(|err| err.to_string())
1132                                    .collect::<Vec<_>>()
1133                                    .join("\n");
1134                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1135                                    header: "Error interacting with language model".into(),
1136                                    message: SharedString::from(error_message.clone()),
1137                                }));
1138                            }
1139
1140                            thread.cancel_last_completion(cx);
1141                        }
1142                    }
1143                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1144
1145                    thread.auto_capture_telemetry(cx);
1146
1147                    if let Ok(initial_usage) = initial_token_usage {
1148                        let usage = thread.cumulative_token_usage.clone() - initial_usage;
1149
1150                        telemetry::event!(
1151                            "Assistant Thread Completion",
1152                            thread_id = thread.id().to_string(),
1153                            model = model.telemetry_id(),
1154                            model_provider = model.provider_id().to_string(),
1155                            input_tokens = usage.input_tokens,
1156                            output_tokens = usage.output_tokens,
1157                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1158                            cache_read_input_tokens = usage.cache_read_input_tokens,
1159                        );
1160                    }
1161                })
1162                .ok();
1163        });
1164
1165        self.pending_completions.push(PendingCompletion {
1166            id: pending_completion_id,
1167            _task: task,
1168        });
1169    }
1170
1171    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1172        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1173            return;
1174        };
1175
1176        if !model.provider.is_authenticated(cx) {
1177            return;
1178        }
1179
1180        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1181        request.messages.push(LanguageModelRequestMessage {
1182            role: Role::User,
1183            content: vec![
1184                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1185                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1186                 If the conversation is about a specific subject, include it in the title. \
1187                 Be descriptive. DO NOT speak in the first person."
1188                    .into(),
1189            ],
1190            cache: false,
1191        });
1192
1193        self.pending_summary = cx.spawn(async move |this, cx| {
1194            async move {
1195                let stream = model.model.stream_completion_text(request, &cx);
1196                let mut messages = stream.await?;
1197
1198                let mut new_summary = String::new();
1199                while let Some(message) = messages.stream.next().await {
1200                    let text = message?;
1201                    let mut lines = text.lines();
1202                    new_summary.extend(lines.next());
1203
1204                    // Stop if the LLM generated multiple lines.
1205                    if lines.next().is_some() {
1206                        break;
1207                    }
1208                }
1209
1210                this.update(cx, |this, cx| {
1211                    if !new_summary.is_empty() {
1212                        this.summary = Some(new_summary.into());
1213                    }
1214
1215                    cx.emit(ThreadEvent::SummaryGenerated);
1216                })?;
1217
1218                anyhow::Ok(())
1219            }
1220            .log_err()
1221            .await
1222        });
1223    }
1224
1225    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1226        let last_message_id = self.messages.last().map(|message| message.id)?;
1227
1228        match &self.detailed_summary_state {
1229            DetailedSummaryState::Generating { message_id, .. }
1230            | DetailedSummaryState::Generated { message_id, .. }
1231                if *message_id == last_message_id =>
1232            {
1233                // Already up-to-date
1234                return None;
1235            }
1236            _ => {}
1237        }
1238
1239        let ConfiguredModel { model, provider } =
1240            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1241
1242        if !provider.is_authenticated(cx) {
1243            return None;
1244        }
1245
1246        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1247
1248        request.messages.push(LanguageModelRequestMessage {
1249            role: Role::User,
1250            content: vec![
1251                "Generate a detailed summary of this conversation. Include:\n\
1252                1. A brief overview of what was discussed\n\
1253                2. Key facts or information discovered\n\
1254                3. Outcomes or conclusions reached\n\
1255                4. Any action items or next steps if any\n\
1256                Format it in Markdown with headings and bullet points."
1257                    .into(),
1258            ],
1259            cache: false,
1260        });
1261
1262        let task = cx.spawn(async move |thread, cx| {
1263            let stream = model.stream_completion_text(request, &cx);
1264            let Some(mut messages) = stream.await.log_err() else {
1265                thread
1266                    .update(cx, |this, _cx| {
1267                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1268                    })
1269                    .log_err();
1270
1271                return;
1272            };
1273
1274            let mut new_detailed_summary = String::new();
1275
1276            while let Some(chunk) = messages.stream.next().await {
1277                if let Some(chunk) = chunk.log_err() {
1278                    new_detailed_summary.push_str(&chunk);
1279                }
1280            }
1281
1282            thread
1283                .update(cx, |this, _cx| {
1284                    this.detailed_summary_state = DetailedSummaryState::Generated {
1285                        text: new_detailed_summary.into(),
1286                        message_id: last_message_id,
1287                    };
1288                })
1289                .log_err();
1290        });
1291
1292        self.detailed_summary_state = DetailedSummaryState::Generating {
1293            message_id: last_message_id,
1294        };
1295
1296        Some(task)
1297    }
1298
1299    pub fn is_generating_detailed_summary(&self) -> bool {
1300        matches!(
1301            self.detailed_summary_state,
1302            DetailedSummaryState::Generating { .. }
1303        )
1304    }
1305
1306    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1307        self.auto_capture_telemetry(cx);
1308        let request = self.to_completion_request(RequestKind::Chat, cx);
1309        let messages = Arc::new(request.messages);
1310        let pending_tool_uses = self
1311            .tool_use
1312            .pending_tool_uses()
1313            .into_iter()
1314            .filter(|tool_use| tool_use.status.is_idle())
1315            .cloned()
1316            .collect::<Vec<_>>();
1317
1318        for tool_use in pending_tool_uses.iter() {
1319            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1320                if tool.needs_confirmation(&tool_use.input, cx)
1321                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1322                {
1323                    self.tool_use.confirm_tool_use(
1324                        tool_use.id.clone(),
1325                        tool_use.ui_text.clone(),
1326                        tool_use.input.clone(),
1327                        messages.clone(),
1328                        tool,
1329                    );
1330                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1331                } else {
1332                    self.run_tool(
1333                        tool_use.id.clone(),
1334                        tool_use.ui_text.clone(),
1335                        tool_use.input.clone(),
1336                        &messages,
1337                        tool,
1338                        cx,
1339                    );
1340                }
1341            }
1342        }
1343
1344        pending_tool_uses
1345    }
1346
1347    pub fn run_tool(
1348        &mut self,
1349        tool_use_id: LanguageModelToolUseId,
1350        ui_text: impl Into<SharedString>,
1351        input: serde_json::Value,
1352        messages: &[LanguageModelRequestMessage],
1353        tool: Arc<dyn Tool>,
1354        cx: &mut Context<Thread>,
1355    ) {
1356        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1357        self.tool_use
1358            .run_pending_tool(tool_use_id, ui_text.into(), task);
1359    }
1360
1361    fn spawn_tool_use(
1362        &mut self,
1363        tool_use_id: LanguageModelToolUseId,
1364        messages: &[LanguageModelRequestMessage],
1365        input: serde_json::Value,
1366        tool: Arc<dyn Tool>,
1367        cx: &mut Context<Thread>,
1368    ) -> Task<()> {
1369        let tool_name: Arc<str> = tool.name().into();
1370
1371        let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1372            Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1373        } else {
1374            tool.run(
1375                input,
1376                messages,
1377                self.project.clone(),
1378                self.action_log.clone(),
1379                cx,
1380            )
1381        };
1382
1383        cx.spawn({
1384            async move |thread: WeakEntity<Thread>, cx| {
1385                let output = run_tool.await;
1386
1387                thread
1388                    .update(cx, |thread, cx| {
1389                        let pending_tool_use = thread.tool_use.insert_tool_output(
1390                            tool_use_id.clone(),
1391                            tool_name,
1392                            output,
1393                            cx,
1394                        );
1395                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1396                    })
1397                    .ok();
1398            }
1399        })
1400    }
1401
1402    fn tool_finished(
1403        &mut self,
1404        tool_use_id: LanguageModelToolUseId,
1405        pending_tool_use: Option<PendingToolUse>,
1406        canceled: bool,
1407        cx: &mut Context<Self>,
1408    ) {
1409        if self.all_tools_finished() {
1410            let model_registry = LanguageModelRegistry::read_global(cx);
1411            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1412                self.attach_tool_results(cx);
1413                if !canceled {
1414                    self.send_to_model(model, RequestKind::Chat, cx);
1415                }
1416            }
1417        }
1418
1419        cx.emit(ThreadEvent::ToolFinished {
1420            tool_use_id,
1421            pending_tool_use,
1422        });
1423    }
1424
1425    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1426        // Insert a user message to contain the tool results.
1427        self.insert_user_message(
1428            // TODO: Sending up a user message without any content results in the model sending back
1429            // responses that also don't have any content. We currently don't handle this case well,
1430            // so for now we provide some text to keep the model on track.
1431            "Here are the tool results.",
1432            Vec::new(),
1433            None,
1434            cx,
1435        );
1436    }
1437
1438    /// Cancels the last pending completion, if there are any pending.
1439    ///
1440    /// Returns whether a completion was canceled.
1441    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1442        let canceled = if self.pending_completions.pop().is_some() {
1443            true
1444        } else {
1445            let mut canceled = false;
1446            for pending_tool_use in self.tool_use.cancel_pending() {
1447                canceled = true;
1448                self.tool_finished(
1449                    pending_tool_use.id.clone(),
1450                    Some(pending_tool_use),
1451                    true,
1452                    cx,
1453                );
1454            }
1455            canceled
1456        };
1457        self.finalize_pending_checkpoint(cx);
1458        canceled
1459    }
1460
    /// Returns the thread-level feedback, if any.
    ///
    /// `report_feedback` only stores feedback here when the thread has no assistant message;
    /// otherwise feedback is recorded per message (see `message_feedback`).
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1464
1465    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1466        self.message_feedback.get(&message_id).copied()
1467    }
1468
1469    pub fn report_message_feedback(
1470        &mut self,
1471        message_id: MessageId,
1472        feedback: ThreadFeedback,
1473        cx: &mut Context<Self>,
1474    ) -> Task<Result<()>> {
1475        if self.message_feedback.get(&message_id) == Some(&feedback) {
1476            return Task::ready(Ok(()));
1477        }
1478
1479        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1480        let serialized_thread = self.serialize(cx);
1481        let thread_id = self.id().clone();
1482        let client = self.project.read(cx).client();
1483
1484        let enabled_tool_names: Vec<String> = self
1485            .tools()
1486            .enabled_tools(cx)
1487            .iter()
1488            .map(|tool| tool.name().to_string())
1489            .collect();
1490
1491        self.message_feedback.insert(message_id, feedback);
1492
1493        cx.notify();
1494
1495        let message_content = self
1496            .message(message_id)
1497            .map(|msg| msg.to_string())
1498            .unwrap_or_default();
1499
1500        cx.background_spawn(async move {
1501            let final_project_snapshot = final_project_snapshot.await;
1502            let serialized_thread = serialized_thread.await?;
1503            let thread_data =
1504                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1505
1506            let rating = match feedback {
1507                ThreadFeedback::Positive => "positive",
1508                ThreadFeedback::Negative => "negative",
1509            };
1510            telemetry::event!(
1511                "Assistant Thread Rated",
1512                rating,
1513                thread_id,
1514                enabled_tool_names,
1515                message_id = message_id.0,
1516                message_content,
1517                thread_data,
1518                final_project_snapshot
1519            );
1520            client.telemetry().flush_events();
1521
1522            Ok(())
1523        })
1524    }
1525
1526    pub fn report_feedback(
1527        &mut self,
1528        feedback: ThreadFeedback,
1529        cx: &mut Context<Self>,
1530    ) -> Task<Result<()>> {
1531        let last_assistant_message_id = self
1532            .messages
1533            .iter()
1534            .rev()
1535            .find(|msg| msg.role == Role::Assistant)
1536            .map(|msg| msg.id);
1537
1538        if let Some(message_id) = last_assistant_message_id {
1539            self.report_message_feedback(message_id, feedback, cx)
1540        } else {
1541            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1542            let serialized_thread = self.serialize(cx);
1543            let thread_id = self.id().clone();
1544            let client = self.project.read(cx).client();
1545            self.feedback = Some(feedback);
1546            cx.notify();
1547
1548            cx.background_spawn(async move {
1549                let final_project_snapshot = final_project_snapshot.await;
1550                let serialized_thread = serialized_thread.await?;
1551                let thread_data = serde_json::to_value(serialized_thread)
1552                    .unwrap_or_else(|_| serde_json::Value::Null);
1553
1554                let rating = match feedback {
1555                    ThreadFeedback::Positive => "positive",
1556                    ThreadFeedback::Negative => "negative",
1557                };
1558                telemetry::event!(
1559                    "Assistant Thread Rated",
1560                    rating,
1561                    thread_id,
1562                    thread_data,
1563                    final_project_snapshot
1564                );
1565                client.telemetry().flush_events();
1566
1567                Ok(())
1568            })
1569        }
1570    }
1571
1572    /// Create a snapshot of the current project state including git information and unsaved buffers.
1573    fn project_snapshot(
1574        project: Entity<Project>,
1575        cx: &mut Context<Self>,
1576    ) -> Task<Arc<ProjectSnapshot>> {
1577        let git_store = project.read(cx).git_store().clone();
1578        let worktree_snapshots: Vec<_> = project
1579            .read(cx)
1580            .visible_worktrees(cx)
1581            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1582            .collect();
1583
1584        cx.spawn(async move |_, cx| {
1585            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1586
1587            let mut unsaved_buffers = Vec::new();
1588            cx.update(|app_cx| {
1589                let buffer_store = project.read(app_cx).buffer_store();
1590                for buffer_handle in buffer_store.read(app_cx).buffers() {
1591                    let buffer = buffer_handle.read(app_cx);
1592                    if buffer.is_dirty() {
1593                        if let Some(file) = buffer.file() {
1594                            let path = file.path().to_string_lossy().to_string();
1595                            unsaved_buffers.push(path);
1596                        }
1597                    }
1598                }
1599            })
1600            .ok();
1601
1602            Arc::new(ProjectSnapshot {
1603                worktree_snapshots,
1604                unsaved_buffer_paths: unsaved_buffers,
1605                timestamp: Utc::now(),
1606            })
1607        })
1608    }
1609
1610    fn worktree_snapshot(
1611        worktree: Entity<project::Worktree>,
1612        git_store: Entity<GitStore>,
1613        cx: &App,
1614    ) -> Task<WorktreeSnapshot> {
1615        cx.spawn(async move |cx| {
1616            // Get worktree path and snapshot
1617            let worktree_info = cx.update(|app_cx| {
1618                let worktree = worktree.read(app_cx);
1619                let path = worktree.abs_path().to_string_lossy().to_string();
1620                let snapshot = worktree.snapshot();
1621                (path, snapshot)
1622            });
1623
1624            let Ok((worktree_path, _snapshot)) = worktree_info else {
1625                return WorktreeSnapshot {
1626                    worktree_path: String::new(),
1627                    git_state: None,
1628                };
1629            };
1630
1631            let git_state = git_store
1632                .update(cx, |git_store, cx| {
1633                    git_store
1634                        .repositories()
1635                        .values()
1636                        .find(|repo| {
1637                            repo.read(cx)
1638                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1639                                .is_some()
1640                        })
1641                        .cloned()
1642                })
1643                .ok()
1644                .flatten()
1645                .map(|repo| {
1646                    repo.update(cx, |repo, _| {
1647                        let current_branch =
1648                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1649                        repo.send_job(None, |state, _| async move {
1650                            let RepositoryState::Local { backend, .. } = state else {
1651                                return GitState {
1652                                    remote_url: None,
1653                                    head_sha: None,
1654                                    current_branch,
1655                                    diff: None,
1656                                };
1657                            };
1658
1659                            let remote_url = backend.remote_url("origin");
1660                            let head_sha = backend.head_sha();
1661                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1662
1663                            GitState {
1664                                remote_url,
1665                                head_sha,
1666                                current_branch,
1667                                diff,
1668                            }
1669                        })
1670                    })
1671                });
1672
1673            let git_state = match git_state {
1674                Some(git_state) => match git_state.ok() {
1675                    Some(git_state) => git_state.await.ok(),
1676                    None => None,
1677                },
1678                None => None,
1679            };
1680
1681            WorktreeSnapshot {
1682                worktree_path,
1683                git_state,
1684            }
1685        })
1686    }
1687
1688    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1689        let mut markdown = Vec::new();
1690
1691        if let Some(summary) = self.summary() {
1692            writeln!(markdown, "# {summary}\n")?;
1693        };
1694
1695        for message in self.messages() {
1696            writeln!(
1697                markdown,
1698                "## {role}\n",
1699                role = match message.role {
1700                    Role::User => "User",
1701                    Role::Assistant => "Assistant",
1702                    Role::System => "System",
1703                }
1704            )?;
1705
1706            if !message.context.is_empty() {
1707                writeln!(markdown, "{}", message.context)?;
1708            }
1709
1710            for segment in &message.segments {
1711                match segment {
1712                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1713                    MessageSegment::Thinking(text) => {
1714                        writeln!(markdown, "<think>{}</think>\n", text)?
1715                    }
1716                }
1717            }
1718
1719            for tool_use in self.tool_uses_for_message(message.id, cx) {
1720                writeln!(
1721                    markdown,
1722                    "**Use Tool: {} ({})**",
1723                    tool_use.name, tool_use.id
1724                )?;
1725                writeln!(markdown, "```json")?;
1726                writeln!(
1727                    markdown,
1728                    "{}",
1729                    serde_json::to_string_pretty(&tool_use.input)?
1730                )?;
1731                writeln!(markdown, "```")?;
1732            }
1733
1734            for tool_result in self.tool_results_for_message(message.id) {
1735                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1736                if tool_result.is_error {
1737                    write!(markdown, " (Error)")?;
1738                }
1739
1740                writeln!(markdown, "**\n")?;
1741                writeln!(markdown, "{}", tool_result.content)?;
1742            }
1743        }
1744
1745        Ok(String::from_utf8_lossy(&markdown).to_string())
1746    }
1747
1748    pub fn keep_edits_in_range(
1749        &mut self,
1750        buffer: Entity<language::Buffer>,
1751        buffer_range: Range<language::Anchor>,
1752        cx: &mut Context<Self>,
1753    ) {
1754        self.action_log.update(cx, |action_log, cx| {
1755            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1756        });
1757    }
1758
1759    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1760        self.action_log
1761            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1762    }
1763
1764    pub fn reject_edits_in_range(
1765        &mut self,
1766        buffer: Entity<language::Buffer>,
1767        buffer_range: Range<language::Anchor>,
1768        cx: &mut Context<Self>,
1769    ) -> Task<Result<()>> {
1770        self.action_log.update(cx, |action_log, cx| {
1771            action_log.reject_edits_in_range(buffer, buffer_range, cx)
1772        })
1773    }
1774
1775    pub fn action_log(&self) -> &Entity<ActionLog> {
1776        &self.action_log
1777    }
1778
1779    pub fn project(&self) -> &Entity<Project> {
1780        &self.project
1781    }
1782
1783    pub fn cumulative_token_usage(&self) -> TokenUsage {
1784        self.cumulative_token_usage.clone()
1785    }
1786
1787    pub fn auto_capture_telemetry(&self, cx: &mut Context<Self>) {
1788        static mut LAST_CAPTURE: Option<std::time::Instant> = None;
1789        let now = std::time::Instant::now();
1790        let should_check = unsafe {
1791            if let Some(last) = LAST_CAPTURE {
1792                if now.duration_since(last).as_secs() < 10 {
1793                    return;
1794                }
1795            }
1796            LAST_CAPTURE = Some(now);
1797            true
1798        };
1799
1800        if !should_check {
1801            return;
1802        }
1803
1804        let feature_flag_enabled = cx.has_flag::<feature_flags::ThreadAutoCapture>();
1805
1806        if cfg!(debug_assertions) {
1807            if !feature_flag_enabled {
1808                return;
1809            }
1810        }
1811
1812        let thread_id = self.id().clone();
1813
1814        let github_handle = self
1815            .project
1816            .read(cx)
1817            .user_store()
1818            .read(cx)
1819            .current_user()
1820            .map(|user| user.github_login.clone());
1821
1822        let client = self.project.read(cx).client().clone();
1823
1824        let serialized_thread = self.serialize(cx);
1825
1826        cx.foreground_executor()
1827            .spawn(async move {
1828                if let Ok(serialized_thread) = serialized_thread.await {
1829                    let thread_data = serde_json::to_value(serialized_thread)
1830                        .unwrap_or_else(|_| serde_json::Value::Null);
1831
1832                    telemetry::event!(
1833                        "Agent Thread AutoCaptured",
1834                        thread_id = thread_id.to_string(),
1835                        thread_data = thread_data,
1836                        auto_capture_reason = "tracked_user",
1837                        github_handle = github_handle
1838                    );
1839
1840                    client.telemetry().flush_events();
1841                }
1842            })
1843            .detach();
1844    }
1845
1846    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1847        let model_registry = LanguageModelRegistry::read_global(cx);
1848        let Some(model) = model_registry.default_model() else {
1849            return TotalTokenUsage::default();
1850        };
1851
1852        let max = model.model.max_token_count();
1853
1854        #[cfg(debug_assertions)]
1855        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1856            .unwrap_or("0.8".to_string())
1857            .parse()
1858            .unwrap();
1859        #[cfg(not(debug_assertions))]
1860        let warning_threshold: f32 = 0.8;
1861
1862        let total = self.cumulative_token_usage.total_tokens() as usize;
1863
1864        let ratio = if total >= max {
1865            TokenUsageRatio::Exceeded
1866        } else if total as f32 / max as f32 >= warning_threshold {
1867            TokenUsageRatio::Warning
1868        } else {
1869            TokenUsageRatio::Normal
1870        };
1871
1872        TotalTokenUsage { total, max, ratio }
1873    }
1874
1875    pub fn deny_tool_use(
1876        &mut self,
1877        tool_use_id: LanguageModelToolUseId,
1878        tool_name: Arc<str>,
1879        cx: &mut Context<Self>,
1880    ) {
1881        let err = Err(anyhow::anyhow!(
1882            "Permission to run tool action denied by user"
1883        ));
1884
1885        self.tool_use
1886            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1887        self.tool_finished(tool_use_id.clone(), None, true, cx);
1888    }
1889}
1890
/// Error carried by [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Displays as "Payment required".
    // NOTE(review): presumably built from `language_model::PaymentRequiredError`;
    // the conversion lives outside this chunk — confirm.
    #[error("Payment required")]
    PaymentRequired,
    /// Displays as "Max monthly spend reached".
    // NOTE(review): presumably built from `MaxMonthlySpendReachedError` — confirm.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// A free-form error with a title-like `header` and a `message` body; displays
    /// as `Message {header}: {message}`.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1903
/// Events a [`Thread`] emits to its subscribers.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error to show to the user.
    ShowError(ThreadError),
    // NOTE(review): presumably emitted as completion stream events arrive — confirm.
    StreamedCompletion,
    /// Assistant text streamed into the message with the given id.
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text streamed into the message with the given id.
    StreamedAssistantThinking(MessageId, String),
    /// A completion stopped, either with a [`StopReason`] or with an error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// A message with the given id was added to the thread.
    MessageAdded(MessageId),
    /// The message with the given id was edited.
    MessageEdited(MessageId),
    /// The message with the given id was deleted.
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    /// Tool uses that are pending.
    // NOTE(review): presumably subscribers are expected to run these — confirm.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// The tool use identified by `tool_use_id` finished.
    ToolFinished {
        // NOTE(review): marked `allow(unused)`, which suggests nothing currently
        // reads this field — confirm and document why it is kept.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    // NOTE(review): presumably signals that a tool use awaits user confirmation.
    ToolConfirmationNeeded,
}
1928
// Lets `Thread` entities emit `ThreadEvent`s to subscribers.
impl EventEmitter<ThreadEvent> for Thread {}
1930
/// An in-flight completion request.
struct PendingCompletion {
    /// Identifier for this completion.
    // NOTE(review): presumably used to tell completions apart — confirm at call sites.
    id: usize,
    /// Owned so the completion task stays alive alongside this value; the field is
    /// never read (hence the underscore).
    // NOTE(review): assumes dropping a gpui `Task` cancels it — confirm.
    _task: Task<()>,
}
1935
// Tests for how `Thread` attaches user-supplied context to messages and how it
// reports buffers that changed after being attached.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// A user message with one attached file stores the rendered `<context>` block
    /// separately from its text segment, and the completion request prepends that
    /// context to the message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request. Index 0 is a leading message that precedes
        // the user messages (presumably the system prompt).
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    /// Context already attached to an earlier message is not repeated. Each
    /// message's context contains only files new to the thread, both on the stored
    /// message and in the completion request.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain 4 messages: one leading message (presumably the
        // system prompt) followed by the 3 user messages at indices 1..=3
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    /// Messages inserted without any context have an empty `context`. The
    /// completion request then carries only their text.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    /// After a buffer attached as context is edited, the next completion request
    /// ends with a user message listing the changed file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at offset 1 (right after the first character). The exact
            // position is irrelevant; the test only needs the buffer to change.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Installs a test `SettingsStore` as a global and registers/initializes the
    /// settings these tests rely on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files: `files` is inserted into a
    // fake filesystem under `/test`, which becomes the project's only root.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Opens a test workspace window for `project`. Returns that workspace, a loaded
    /// `ThreadStore`, a fresh thread created from the store, and an empty
    /// `ContextStore`.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    Arc::default(),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await;

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    /// Opens the buffer at `path` (e.g. `"test/code.rs"`) and adds it to
    /// `context_store`, returning the buffer.
    ///
    /// Panics if the path cannot be resolved or the buffer cannot be opened. Returns
    /// an error only if adding the buffer to the context store fails.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}