thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5
   6use agent_rules::load_worktree_rules_file;
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use fs::Fs;
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
  19    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  20    LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  21    PaymentRequiredError, Role, StopReason, TokenUsage,
  22};
  23use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  24use project::{Project, Worktree};
  25use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
  26use schemars::JsonSchema;
  27use serde::{Deserialize, Serialize};
  28use settings::Settings;
  29use util::{ResultExt as _, TryFutureExt as _, post_inc};
  30use uuid::Uuid;
  31
  32use crate::context::{AssistantContext, ContextId, format_context_as_string};
  33use crate::thread_store::{
  34    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  35    SerializedToolUse,
  36};
  37use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  38
  39#[derive(Debug, Clone, Copy)]
  40pub enum RequestKind {
  41    Chat,
  42    /// Used when summarizing a thread.
  43    Summarize,
  44}
  45
  46#[derive(
  47    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  48)]
  49pub struct ThreadId(Arc<str>);
  50
  51impl ThreadId {
  52    pub fn new() -> Self {
  53        Self(Uuid::new_v4().to_string().into())
  54    }
  55}
  56
  57impl std::fmt::Display for ThreadId {
  58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  59        write!(f, "{}", self.0)
  60    }
  61}
  62
  63impl From<&str> for ThreadId {
  64    fn from(value: &str) -> Self {
  65        Self(value.into())
  66    }
  67}
  68
  69#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  70pub struct MessageId(pub(crate) usize);
  71
  72impl MessageId {
  73    fn post_inc(&mut self) -> Self {
  74        Self(post_inc(&mut self.0))
  75    }
  76}
  77
  78/// A message in a [`Thread`].
  79#[derive(Debug, Clone)]
  80pub struct Message {
  81    pub id: MessageId,
  82    pub role: Role,
  83    pub segments: Vec<MessageSegment>,
  84    pub context: String,
  85}
  86
  87impl Message {
  88    /// Returns whether the message contains any meaningful text that should be displayed
  89    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  90    pub fn should_display_content(&self) -> bool {
  91        self.segments.iter().all(|segment| segment.should_display())
  92    }
  93
  94    pub fn push_thinking(&mut self, text: &str) {
  95        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  96            segment.push_str(text);
  97        } else {
  98            self.segments
  99                .push(MessageSegment::Thinking(text.to_string()));
 100        }
 101    }
 102
 103    pub fn push_text(&mut self, text: &str) {
 104        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 105            segment.push_str(text);
 106        } else {
 107            self.segments.push(MessageSegment::Text(text.to_string()));
 108        }
 109    }
 110
 111    pub fn to_string(&self) -> String {
 112        let mut result = String::new();
 113
 114        if !self.context.is_empty() {
 115            result.push_str(&self.context);
 116        }
 117
 118        for segment in &self.segments {
 119            match segment {
 120                MessageSegment::Text(text) => result.push_str(text),
 121                MessageSegment::Thinking(text) => {
 122                    result.push_str("<think>");
 123                    result.push_str(text);
 124                    result.push_str("</think>");
 125                }
 126            }
 127        }
 128
 129        result
 130    }
 131}
 132
 133#[derive(Debug, Clone, PartialEq, Eq)]
 134pub enum MessageSegment {
 135    Text(String),
 136    Thinking(String),
 137}
 138
 139impl MessageSegment {
 140    pub fn text_mut(&mut self) -> &mut String {
 141        match self {
 142            Self::Text(text) => text,
 143            Self::Thinking(text) => text,
 144        }
 145    }
 146
 147    pub fn should_display(&self) -> bool {
 148        // We add USING_TOOL_MARKER when making a request that includes tool uses
 149        // without non-whitespace text around them, and this can cause the model
 150        // to mimic the pattern, so we consider those segments not displayable.
 151        match self {
 152            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 153            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 154        }
 155    }
 156}
 157
 158#[derive(Debug, Clone, Serialize, Deserialize)]
 159pub struct ProjectSnapshot {
 160    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 161    pub unsaved_buffer_paths: Vec<String>,
 162    pub timestamp: DateTime<Utc>,
 163}
 164
 165#[derive(Debug, Clone, Serialize, Deserialize)]
 166pub struct WorktreeSnapshot {
 167    pub worktree_path: String,
 168    pub git_state: Option<GitState>,
 169}
 170
 171#[derive(Debug, Clone, Serialize, Deserialize)]
 172pub struct GitState {
 173    pub remote_url: Option<String>,
 174    pub head_sha: Option<String>,
 175    pub current_branch: Option<String>,
 176    pub diff: Option<String>,
 177}
 178
 179#[derive(Clone)]
 180pub struct ThreadCheckpoint {
 181    message_id: MessageId,
 182    git_checkpoint: GitStoreCheckpoint,
 183}
 184
 185#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 186pub enum ThreadFeedback {
 187    Positive,
 188    Negative,
 189}
 190
 191pub enum LastRestoreCheckpoint {
 192    Pending {
 193        message_id: MessageId,
 194    },
 195    Error {
 196        message_id: MessageId,
 197        error: String,
 198    },
 199}
 200
 201impl LastRestoreCheckpoint {
 202    pub fn message_id(&self) -> MessageId {
 203        match self {
 204            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 205            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 206        }
 207    }
 208}
 209
 210#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 211pub enum DetailedSummaryState {
 212    #[default]
 213    NotGenerated,
 214    Generating {
 215        message_id: MessageId,
 216    },
 217    Generated {
 218        text: SharedString,
 219        message_id: MessageId,
 220    },
 221}
 222
 223#[derive(Default)]
 224pub struct TotalTokenUsage {
 225    pub total: usize,
 226    pub max: usize,
 227    pub ratio: TokenUsageRatio,
 228}
 229
 230#[derive(Default, PartialEq, Eq)]
 231pub enum TokenUsageRatio {
 232    #[default]
 233    Normal,
 234    Warning,
 235    Exceeded,
 236}
 237
 238/// A thread of conversation with the LLM.
 239pub struct Thread {
 240    id: ThreadId,
 241    updated_at: DateTime<Utc>,
 242    summary: Option<SharedString>,
 243    pending_summary: Task<Option<()>>,
 244    detailed_summary_state: DetailedSummaryState,
 245    messages: Vec<Message>,
 246    next_message_id: MessageId,
 247    context: BTreeMap<ContextId, AssistantContext>,
 248    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 249    system_prompt_context: Option<AssistantSystemPromptContext>,
 250    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 251    completion_count: usize,
 252    pending_completions: Vec<PendingCompletion>,
 253    project: Entity<Project>,
 254    prompt_builder: Arc<PromptBuilder>,
 255    tools: Arc<ToolWorkingSet>,
 256    tool_use: ToolUseState,
 257    action_log: Entity<ActionLog>,
 258    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 259    pending_checkpoint: Option<ThreadCheckpoint>,
 260    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 261    cumulative_token_usage: TokenUsage,
 262    feedback: Option<ThreadFeedback>,
 263    message_feedback: HashMap<MessageId, ThreadFeedback>,
 264}
 265
 266impl Thread {
 267    pub fn new(
 268        project: Entity<Project>,
 269        tools: Arc<ToolWorkingSet>,
 270        prompt_builder: Arc<PromptBuilder>,
 271        cx: &mut Context<Self>,
 272    ) -> Self {
 273        Self {
 274            id: ThreadId::new(),
 275            updated_at: Utc::now(),
 276            summary: None,
 277            pending_summary: Task::ready(None),
 278            detailed_summary_state: DetailedSummaryState::NotGenerated,
 279            messages: Vec::new(),
 280            next_message_id: MessageId(0),
 281            context: BTreeMap::default(),
 282            context_by_message: HashMap::default(),
 283            system_prompt_context: None,
 284            checkpoints_by_message: HashMap::default(),
 285            completion_count: 0,
 286            pending_completions: Vec::new(),
 287            project: project.clone(),
 288            prompt_builder,
 289            tools: tools.clone(),
 290            last_restore_checkpoint: None,
 291            pending_checkpoint: None,
 292            tool_use: ToolUseState::new(tools.clone()),
 293            action_log: cx.new(|_| ActionLog::new(project.clone())),
 294            initial_project_snapshot: {
 295                let project_snapshot = Self::project_snapshot(project, cx);
 296                cx.foreground_executor()
 297                    .spawn(async move { Some(project_snapshot.await) })
 298                    .shared()
 299            },
 300            cumulative_token_usage: TokenUsage::default(),
 301            feedback: None,
 302            message_feedback: HashMap::default(),
 303        }
 304    }
 305
 306    pub fn deserialize(
 307        id: ThreadId,
 308        serialized: SerializedThread,
 309        project: Entity<Project>,
 310        tools: Arc<ToolWorkingSet>,
 311        prompt_builder: Arc<PromptBuilder>,
 312        cx: &mut Context<Self>,
 313    ) -> Self {
 314        let next_message_id = MessageId(
 315            serialized
 316                .messages
 317                .last()
 318                .map(|message| message.id.0 + 1)
 319                .unwrap_or(0),
 320        );
 321        let tool_use =
 322            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 323
 324        Self {
 325            id,
 326            updated_at: serialized.updated_at,
 327            summary: Some(serialized.summary),
 328            pending_summary: Task::ready(None),
 329            detailed_summary_state: serialized.detailed_summary_state,
 330            messages: serialized
 331                .messages
 332                .into_iter()
 333                .map(|message| Message {
 334                    id: message.id,
 335                    role: message.role,
 336                    segments: message
 337                        .segments
 338                        .into_iter()
 339                        .map(|segment| match segment {
 340                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 341                            SerializedMessageSegment::Thinking { text } => {
 342                                MessageSegment::Thinking(text)
 343                            }
 344                        })
 345                        .collect(),
 346                    context: message.context,
 347                })
 348                .collect(),
 349            next_message_id,
 350            context: BTreeMap::default(),
 351            context_by_message: HashMap::default(),
 352            system_prompt_context: None,
 353            checkpoints_by_message: HashMap::default(),
 354            completion_count: 0,
 355            pending_completions: Vec::new(),
 356            last_restore_checkpoint: None,
 357            pending_checkpoint: None,
 358            project: project.clone(),
 359            prompt_builder,
 360            tools,
 361            tool_use,
 362            action_log: cx.new(|_| ActionLog::new(project)),
 363            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 364            cumulative_token_usage: serialized.cumulative_token_usage,
 365            feedback: None,
 366            message_feedback: HashMap::default(),
 367        }
 368    }
 369
 370    pub fn id(&self) -> &ThreadId {
 371        &self.id
 372    }
 373
 374    pub fn is_empty(&self) -> bool {
 375        self.messages.is_empty()
 376    }
 377
 378    pub fn updated_at(&self) -> DateTime<Utc> {
 379        self.updated_at
 380    }
 381
 382    pub fn touch_updated_at(&mut self) {
 383        self.updated_at = Utc::now();
 384    }
 385
 386    pub fn summary(&self) -> Option<SharedString> {
 387        self.summary.clone()
 388    }
 389
 390    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 391
 392    pub fn summary_or_default(&self) -> SharedString {
 393        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 394    }
 395
 396    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 397        let Some(current_summary) = &self.summary else {
 398            // Don't allow setting summary until generated
 399            return;
 400        };
 401
 402        let mut new_summary = new_summary.into();
 403
 404        if new_summary.is_empty() {
 405            new_summary = Self::DEFAULT_SUMMARY;
 406        }
 407
 408        if current_summary != &new_summary {
 409            self.summary = Some(new_summary);
 410            cx.emit(ThreadEvent::SummaryChanged);
 411        }
 412    }
 413
 414    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 415        self.latest_detailed_summary()
 416            .unwrap_or_else(|| self.text().into())
 417    }
 418
 419    fn latest_detailed_summary(&self) -> Option<SharedString> {
 420        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 421            Some(text.clone())
 422        } else {
 423            None
 424        }
 425    }
 426
 427    pub fn message(&self, id: MessageId) -> Option<&Message> {
 428        self.messages.iter().find(|message| message.id == id)
 429    }
 430
 431    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 432        self.messages.iter()
 433    }
 434
 435    pub fn is_generating(&self) -> bool {
 436        !self.pending_completions.is_empty() || !self.all_tools_finished()
 437    }
 438
 439    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 440        &self.tools
 441    }
 442
 443    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 444        self.tool_use
 445            .pending_tool_uses()
 446            .into_iter()
 447            .find(|tool_use| &tool_use.id == id)
 448    }
 449
 450    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 451        self.tool_use
 452            .pending_tool_uses()
 453            .into_iter()
 454            .filter(|tool_use| tool_use.status.needs_confirmation())
 455    }
 456
 457    pub fn has_pending_tool_uses(&self) -> bool {
 458        !self.tool_use.pending_tool_uses().is_empty()
 459    }
 460
 461    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 462        self.checkpoints_by_message.get(&id).cloned()
 463    }
 464
 465    pub fn restore_checkpoint(
 466        &mut self,
 467        checkpoint: ThreadCheckpoint,
 468        cx: &mut Context<Self>,
 469    ) -> Task<Result<()>> {
 470        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 471            message_id: checkpoint.message_id,
 472        });
 473        cx.emit(ThreadEvent::CheckpointChanged);
 474        cx.notify();
 475
 476        let git_store = self.project().read(cx).git_store().clone();
 477        let restore = git_store.update(cx, |git_store, cx| {
 478            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 479        });
 480
 481        cx.spawn(async move |this, cx| {
 482            let result = restore.await;
 483            this.update(cx, |this, cx| {
 484                if let Err(err) = result.as_ref() {
 485                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 486                        message_id: checkpoint.message_id,
 487                        error: err.to_string(),
 488                    });
 489                } else {
 490                    this.truncate(checkpoint.message_id, cx);
 491                    this.last_restore_checkpoint = None;
 492                }
 493                this.pending_checkpoint = None;
 494                cx.emit(ThreadEvent::CheckpointChanged);
 495                cx.notify();
 496            })?;
 497            result
 498        })
 499    }
 500
 501    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 502        let pending_checkpoint = if self.is_generating() {
 503            return;
 504        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 505            checkpoint
 506        } else {
 507            return;
 508        };
 509
 510        let git_store = self.project.read(cx).git_store().clone();
 511        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 512        cx.spawn(async move |this, cx| match final_checkpoint.await {
 513            Ok(final_checkpoint) => {
 514                let equal = git_store
 515                    .update(cx, |store, cx| {
 516                        store.compare_checkpoints(
 517                            pending_checkpoint.git_checkpoint.clone(),
 518                            final_checkpoint.clone(),
 519                            cx,
 520                        )
 521                    })?
 522                    .await
 523                    .unwrap_or(false);
 524
 525                if equal {
 526                    git_store
 527                        .update(cx, |store, cx| {
 528                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 529                        })?
 530                        .detach();
 531                } else {
 532                    this.update(cx, |this, cx| {
 533                        this.insert_checkpoint(pending_checkpoint, cx)
 534                    })?;
 535                }
 536
 537                git_store
 538                    .update(cx, |store, cx| {
 539                        store.delete_checkpoint(final_checkpoint, cx)
 540                    })?
 541                    .detach();
 542
 543                Ok(())
 544            }
 545            Err(_) => this.update(cx, |this, cx| {
 546                this.insert_checkpoint(pending_checkpoint, cx)
 547            }),
 548        })
 549        .detach();
 550    }
 551
 552    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 553        self.checkpoints_by_message
 554            .insert(checkpoint.message_id, checkpoint);
 555        cx.emit(ThreadEvent::CheckpointChanged);
 556        cx.notify();
 557    }
 558
 559    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 560        self.last_restore_checkpoint.as_ref()
 561    }
 562
 563    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 564        let Some(message_ix) = self
 565            .messages
 566            .iter()
 567            .rposition(|message| message.id == message_id)
 568        else {
 569            return;
 570        };
 571        for deleted_message in self.messages.drain(message_ix..) {
 572            self.context_by_message.remove(&deleted_message.id);
 573            self.checkpoints_by_message.remove(&deleted_message.id);
 574        }
 575        cx.notify();
 576    }
 577
 578    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 579        self.context_by_message
 580            .get(&id)
 581            .into_iter()
 582            .flat_map(|context| {
 583                context
 584                    .iter()
 585                    .filter_map(|context_id| self.context.get(&context_id))
 586            })
 587    }
 588
 589    /// Returns whether all of the tool uses have finished running.
 590    pub fn all_tools_finished(&self) -> bool {
 591        // If the only pending tool uses left are the ones with errors, then
 592        // that means that we've finished running all of the pending tools.
 593        self.tool_use
 594            .pending_tool_uses()
 595            .iter()
 596            .all(|tool_use| tool_use.status.is_error())
 597    }
 598
 599    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 600        self.tool_use.tool_uses_for_message(id, cx)
 601    }
 602
 603    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 604        self.tool_use.tool_results_for_message(id)
 605    }
 606
 607    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 608        self.tool_use.tool_result(id)
 609    }
 610
 611    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 612        self.tool_use.message_has_tool_results(message_id)
 613    }
 614
 615    pub fn insert_user_message(
 616        &mut self,
 617        text: impl Into<String>,
 618        context: Vec<AssistantContext>,
 619        git_checkpoint: Option<GitStoreCheckpoint>,
 620        cx: &mut Context<Self>,
 621    ) -> MessageId {
 622        let text = text.into();
 623
 624        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 625
 626        // Filter out contexts that have already been included in previous messages
 627        let new_context: Vec<_> = context
 628            .into_iter()
 629            .filter(|ctx| !self.context.contains_key(&ctx.id()))
 630            .collect();
 631
 632        if !new_context.is_empty() {
 633            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 634                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 635                    message.context = context_string;
 636                }
 637            }
 638
 639            self.action_log.update(cx, |log, cx| {
 640                // Track all buffers added as context
 641                for ctx in &new_context {
 642                    match ctx {
 643                        AssistantContext::File(file_ctx) => {
 644                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 645                        }
 646                        AssistantContext::Directory(dir_ctx) => {
 647                            for context_buffer in &dir_ctx.context_buffers {
 648                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 649                            }
 650                        }
 651                        AssistantContext::Symbol(symbol_ctx) => {
 652                            log.buffer_added_as_context(
 653                                symbol_ctx.context_symbol.buffer.clone(),
 654                                cx,
 655                            );
 656                        }
 657                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 658                    }
 659                }
 660            });
 661        }
 662
 663        let context_ids = new_context
 664            .iter()
 665            .map(|context| context.id())
 666            .collect::<Vec<_>>();
 667        self.context.extend(
 668            new_context
 669                .into_iter()
 670                .map(|context| (context.id(), context)),
 671        );
 672        self.context_by_message.insert(message_id, context_ids);
 673
 674        if let Some(git_checkpoint) = git_checkpoint {
 675            self.pending_checkpoint = Some(ThreadCheckpoint {
 676                message_id,
 677                git_checkpoint,
 678            });
 679        }
 680        message_id
 681    }
 682
 683    pub fn insert_message(
 684        &mut self,
 685        role: Role,
 686        segments: Vec<MessageSegment>,
 687        cx: &mut Context<Self>,
 688    ) -> MessageId {
 689        let id = self.next_message_id.post_inc();
 690        self.messages.push(Message {
 691            id,
 692            role,
 693            segments,
 694            context: String::new(),
 695        });
 696        self.touch_updated_at();
 697        cx.emit(ThreadEvent::MessageAdded(id));
 698        id
 699    }
 700
 701    pub fn edit_message(
 702        &mut self,
 703        id: MessageId,
 704        new_role: Role,
 705        new_segments: Vec<MessageSegment>,
 706        cx: &mut Context<Self>,
 707    ) -> bool {
 708        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 709            return false;
 710        };
 711        message.role = new_role;
 712        message.segments = new_segments;
 713        self.touch_updated_at();
 714        cx.emit(ThreadEvent::MessageEdited(id));
 715        true
 716    }
 717
 718    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 719        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 720            return false;
 721        };
 722        self.messages.remove(index);
 723        self.context_by_message.remove(&id);
 724        self.touch_updated_at();
 725        cx.emit(ThreadEvent::MessageDeleted(id));
 726        true
 727    }
 728
 729    /// Returns the representation of this [`Thread`] in a textual form.
 730    ///
 731    /// This is the representation we use when attaching a thread as context to another thread.
 732    pub fn text(&self) -> String {
 733        let mut text = String::new();
 734
 735        for message in &self.messages {
 736            text.push_str(match message.role {
 737                language_model::Role::User => "User:",
 738                language_model::Role::Assistant => "Assistant:",
 739                language_model::Role::System => "System:",
 740            });
 741            text.push('\n');
 742
 743            for segment in &message.segments {
 744                match segment {
 745                    MessageSegment::Text(content) => text.push_str(content),
 746                    MessageSegment::Thinking(content) => {
 747                        text.push_str(&format!("<think>{}</think>", content))
 748                    }
 749                }
 750            }
 751            text.push('\n');
 752        }
 753
 754        text
 755    }
 756
 757    /// Serializes this thread into a format for storage or telemetry.
 758    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 759        let initial_project_snapshot = self.initial_project_snapshot.clone();
 760        cx.spawn(async move |this, cx| {
 761            let initial_project_snapshot = initial_project_snapshot.await;
 762            this.read_with(cx, |this, cx| SerializedThread {
 763                version: SerializedThread::VERSION.to_string(),
 764                summary: this.summary_or_default(),
 765                updated_at: this.updated_at(),
 766                messages: this
 767                    .messages()
 768                    .map(|message| SerializedMessage {
 769                        id: message.id,
 770                        role: message.role,
 771                        segments: message
 772                            .segments
 773                            .iter()
 774                            .map(|segment| match segment {
 775                                MessageSegment::Text(text) => {
 776                                    SerializedMessageSegment::Text { text: text.clone() }
 777                                }
 778                                MessageSegment::Thinking(text) => {
 779                                    SerializedMessageSegment::Thinking { text: text.clone() }
 780                                }
 781                            })
 782                            .collect(),
 783                        tool_uses: this
 784                            .tool_uses_for_message(message.id, cx)
 785                            .into_iter()
 786                            .map(|tool_use| SerializedToolUse {
 787                                id: tool_use.id,
 788                                name: tool_use.name,
 789                                input: tool_use.input,
 790                            })
 791                            .collect(),
 792                        tool_results: this
 793                            .tool_results_for_message(message.id)
 794                            .into_iter()
 795                            .map(|tool_result| SerializedToolResult {
 796                                tool_use_id: tool_result.tool_use_id.clone(),
 797                                is_error: tool_result.is_error,
 798                                content: tool_result.content.clone(),
 799                            })
 800                            .collect(),
 801                        context: message.context.clone(),
 802                    })
 803                    .collect(),
 804                initial_project_snapshot,
 805                cumulative_token_usage: this.cumulative_token_usage.clone(),
 806                detailed_summary_state: this.detailed_summary_state.clone(),
 807            })
 808        })
 809    }
 810
 811    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
 812        self.system_prompt_context = Some(context);
 813    }
 814
 815    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
 816        &self.system_prompt_context
 817    }
 818
 819    pub fn load_system_prompt_context(
 820        &self,
 821        cx: &App,
 822    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
 823        let project = self.project.read(cx);
 824        let tasks = project
 825            .visible_worktrees(cx)
 826            .map(|worktree| {
 827                Self::load_worktree_info_for_system_prompt(
 828                    project.fs().clone(),
 829                    worktree.read(cx),
 830                    cx,
 831                )
 832            })
 833            .collect::<Vec<_>>();
 834
 835        cx.spawn(async |_cx| {
 836            let results = futures::future::join_all(tasks).await;
 837            let mut first_err = None;
 838            let worktrees = results
 839                .into_iter()
 840                .map(|(worktree, err)| {
 841                    if first_err.is_none() && err.is_some() {
 842                        first_err = err;
 843                    }
 844                    worktree
 845                })
 846                .collect::<Vec<_>>();
 847            (AssistantSystemPromptContext::new(worktrees), first_err)
 848        })
 849    }
 850
 851    fn load_worktree_info_for_system_prompt(
 852        fs: Arc<dyn Fs>,
 853        worktree: &Worktree,
 854        cx: &App,
 855    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
 856        let root_name = worktree.root_name().into();
 857        let abs_path = worktree.abs_path();
 858
 859        let rules_task = load_worktree_rules_file(fs, worktree, cx);
 860        let Some(rules_task) = rules_task else {
 861            return Task::ready((
 862                WorktreeInfoForSystemPrompt {
 863                    root_name,
 864                    abs_path,
 865                    rules_file: None,
 866                },
 867                None,
 868            ));
 869        };
 870
 871        cx.spawn(async move |_| {
 872            let (rules_file, rules_file_error) = match rules_task.await {
 873                Ok(rules_file) => (Some(rules_file), None),
 874                Err(err) => (
 875                    None,
 876                    Some(ThreadError::Message {
 877                        header: "Error loading rules file".into(),
 878                        message: format!("{err}").into(),
 879                    }),
 880                ),
 881            };
 882            let worktree_info = WorktreeInfoForSystemPrompt {
 883                root_name,
 884                abs_path,
 885                rules_file,
 886            };
 887            (worktree_info, rules_file_error)
 888        })
 889    }
 890
 891    pub fn send_to_model(
 892        &mut self,
 893        model: Arc<dyn LanguageModel>,
 894        request_kind: RequestKind,
 895        cx: &mut Context<Self>,
 896    ) {
 897        let mut request = self.to_completion_request(request_kind, cx);
 898        if model.supports_tools() {
 899            request.tools = {
 900                let mut tools = Vec::new();
 901                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 902                    LanguageModelRequestTool {
 903                        name: tool.name(),
 904                        description: tool.description(),
 905                        input_schema: tool.input_schema(model.tool_input_format()),
 906                    }
 907                }));
 908
 909                tools
 910            };
 911        }
 912
 913        self.stream_completion(request, model, cx);
 914    }
 915
 916    pub fn used_tools_since_last_user_message(&self) -> bool {
 917        for message in self.messages.iter().rev() {
 918            if self.tool_use.message_has_tool_results(message.id) {
 919                return true;
 920            } else if message.role == Role::User {
 921                return false;
 922            }
 923        }
 924
 925        false
 926    }
 927
 928    pub fn to_completion_request(
 929        &self,
 930        request_kind: RequestKind,
 931        cx: &App,
 932    ) -> LanguageModelRequest {
 933        let mut request = LanguageModelRequest {
 934            messages: vec![],
 935            tools: Vec::new(),
 936            stop: Vec::new(),
 937            temperature: None,
 938        };
 939
 940        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
 941            if let Some(system_prompt) = self
 942                .prompt_builder
 943                .generate_assistant_system_prompt(system_prompt_context)
 944                .context("failed to generate assistant system prompt")
 945                .log_err()
 946            {
 947                request.messages.push(LanguageModelRequestMessage {
 948                    role: Role::System,
 949                    content: vec![MessageContent::Text(system_prompt)],
 950                    cache: true,
 951                });
 952            }
 953        } else {
 954            log::error!("system_prompt_context not set.")
 955        }
 956
 957        for message in &self.messages {
 958            let mut request_message = LanguageModelRequestMessage {
 959                role: message.role,
 960                content: Vec::new(),
 961                cache: false,
 962            };
 963
 964            match request_kind {
 965                RequestKind::Chat => {
 966                    self.tool_use
 967                        .attach_tool_results(message.id, &mut request_message);
 968                }
 969                RequestKind::Summarize => {
 970                    // We don't care about tool use during summarization.
 971                    if self.tool_use.message_has_tool_results(message.id) {
 972                        continue;
 973                    }
 974                }
 975            }
 976
 977            if !message.segments.is_empty() {
 978                request_message
 979                    .content
 980                    .push(MessageContent::Text(message.to_string()));
 981            }
 982
 983            match request_kind {
 984                RequestKind::Chat => {
 985                    self.tool_use
 986                        .attach_tool_uses(message.id, &mut request_message);
 987                }
 988                RequestKind::Summarize => {
 989                    // We don't care about tool use during summarization.
 990                }
 991            };
 992
 993            request.messages.push(request_message);
 994        }
 995
 996        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 997        if let Some(last) = request.messages.last_mut() {
 998            last.cache = true;
 999        }
1000
1001        self.attached_tracked_files_state(&mut request.messages, cx);
1002
1003        request
1004    }
1005
1006    fn attached_tracked_files_state(
1007        &self,
1008        messages: &mut Vec<LanguageModelRequestMessage>,
1009        cx: &App,
1010    ) {
1011        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1012
1013        let mut stale_message = String::new();
1014
1015        let action_log = self.action_log.read(cx);
1016
1017        for stale_file in action_log.stale_buffers(cx) {
1018            let Some(file) = stale_file.read(cx).file() else {
1019                continue;
1020            };
1021
1022            if stale_message.is_empty() {
1023                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1024            }
1025
1026            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1027        }
1028
1029        let mut content = Vec::with_capacity(2);
1030
1031        if !stale_message.is_empty() {
1032            content.push(stale_message.into());
1033        }
1034
1035        if action_log.has_edited_files_since_project_diagnostics_check() {
1036            content.push(
1037                "\n\nWhen you're done making changes, make sure to check project diagnostics \
1038                and fix all errors AND warnings you introduced! \
1039                DO NOT mention you're going to do this until you're done."
1040                    .into(),
1041            );
1042        }
1043
1044        if !content.is_empty() {
1045            let context_message = LanguageModelRequestMessage {
1046                role: Role::User,
1047                content,
1048                cache: false,
1049            };
1050
1051            messages.push(context_message);
1052        }
1053    }
1054
1055    pub fn stream_completion(
1056        &mut self,
1057        request: LanguageModelRequest,
1058        model: Arc<dyn LanguageModel>,
1059        cx: &mut Context<Self>,
1060    ) {
1061        let pending_completion_id = post_inc(&mut self.completion_count);
1062
1063        let task = cx.spawn(async move |thread, cx| {
1064            let stream = model.stream_completion(request, &cx);
1065            let initial_token_usage =
1066                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
1067            let stream_completion = async {
1068                let mut events = stream.await?;
1069                let mut stop_reason = StopReason::EndTurn;
1070                let mut current_token_usage = TokenUsage::default();
1071
1072                while let Some(event) = events.next().await {
1073                    let event = event?;
1074
1075                    thread.update(cx, |thread, cx| {
1076                        match event {
1077                            LanguageModelCompletionEvent::StartMessage { .. } => {
1078                                thread.insert_message(
1079                                    Role::Assistant,
1080                                    vec![MessageSegment::Text(String::new())],
1081                                    cx,
1082                                );
1083                            }
1084                            LanguageModelCompletionEvent::Stop(reason) => {
1085                                stop_reason = reason;
1086                            }
1087                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1088                                thread.cumulative_token_usage =
1089                                    thread.cumulative_token_usage.clone() + token_usage.clone()
1090                                        - current_token_usage.clone();
1091                                current_token_usage = token_usage;
1092                            }
1093                            LanguageModelCompletionEvent::Text(chunk) => {
1094                                if let Some(last_message) = thread.messages.last_mut() {
1095                                    if last_message.role == Role::Assistant {
1096                                        last_message.push_text(&chunk);
1097                                        cx.emit(ThreadEvent::StreamedAssistantText(
1098                                            last_message.id,
1099                                            chunk,
1100                                        ));
1101                                    } else {
1102                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1103                                        // of a new Assistant response.
1104                                        //
1105                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1106                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1107                                        thread.insert_message(
1108                                            Role::Assistant,
1109                                            vec![MessageSegment::Text(chunk.to_string())],
1110                                            cx,
1111                                        );
1112                                    };
1113                                }
1114                            }
1115                            LanguageModelCompletionEvent::Thinking(chunk) => {
1116                                if let Some(last_message) = thread.messages.last_mut() {
1117                                    if last_message.role == Role::Assistant {
1118                                        last_message.push_thinking(&chunk);
1119                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1120                                            last_message.id,
1121                                            chunk,
1122                                        ));
1123                                    } else {
1124                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1125                                        // of a new Assistant response.
1126                                        //
1127                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1128                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1129                                        thread.insert_message(
1130                                            Role::Assistant,
1131                                            vec![MessageSegment::Thinking(chunk.to_string())],
1132                                            cx,
1133                                        );
1134                                    };
1135                                }
1136                            }
1137                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1138                                let last_assistant_message_id = thread
1139                                    .messages
1140                                    .iter_mut()
1141                                    .rfind(|message| message.role == Role::Assistant)
1142                                    .map(|message| message.id)
1143                                    .unwrap_or_else(|| {
1144                                        thread.insert_message(Role::Assistant, vec![], cx)
1145                                    });
1146
1147                                thread.tool_use.request_tool_use(
1148                                    last_assistant_message_id,
1149                                    tool_use,
1150                                    cx,
1151                                );
1152                            }
1153                        }
1154
1155                        thread.touch_updated_at();
1156                        cx.emit(ThreadEvent::StreamedCompletion);
1157                        cx.notify();
1158                    })?;
1159
1160                    smol::future::yield_now().await;
1161                }
1162
1163                thread.update(cx, |thread, cx| {
1164                    thread
1165                        .pending_completions
1166                        .retain(|completion| completion.id != pending_completion_id);
1167
1168                    if thread.summary.is_none() && thread.messages.len() >= 2 {
1169                        thread.summarize(cx);
1170                    }
1171                })?;
1172
1173                anyhow::Ok(stop_reason)
1174            };
1175
1176            let result = stream_completion.await;
1177
1178            thread
1179                .update(cx, |thread, cx| {
1180                    thread.finalize_pending_checkpoint(cx);
1181                    match result.as_ref() {
1182                        Ok(stop_reason) => match stop_reason {
1183                            StopReason::ToolUse => {
1184                                let tool_uses = thread.use_pending_tools(cx);
1185                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1186                            }
1187                            StopReason::EndTurn => {}
1188                            StopReason::MaxTokens => {}
1189                        },
1190                        Err(error) => {
1191                            if error.is::<PaymentRequiredError>() {
1192                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1193                            } else if error.is::<MaxMonthlySpendReachedError>() {
1194                                cx.emit(ThreadEvent::ShowError(
1195                                    ThreadError::MaxMonthlySpendReached,
1196                                ));
1197                            } else {
1198                                let error_message = error
1199                                    .chain()
1200                                    .map(|err| err.to_string())
1201                                    .collect::<Vec<_>>()
1202                                    .join("\n");
1203                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1204                                    header: "Error interacting with language model".into(),
1205                                    message: SharedString::from(error_message.clone()),
1206                                }));
1207                            }
1208
1209                            thread.cancel_last_completion(cx);
1210                        }
1211                    }
1212                    cx.emit(ThreadEvent::DoneStreaming);
1213
1214                    if let Ok(initial_usage) = initial_token_usage {
1215                        let usage = thread.cumulative_token_usage.clone() - initial_usage;
1216
1217                        telemetry::event!(
1218                            "Assistant Thread Completion",
1219                            thread_id = thread.id().to_string(),
1220                            model = model.telemetry_id(),
1221                            model_provider = model.provider_id().to_string(),
1222                            input_tokens = usage.input_tokens,
1223                            output_tokens = usage.output_tokens,
1224                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1225                            cache_read_input_tokens = usage.cache_read_input_tokens,
1226                        );
1227                    }
1228                })
1229                .ok();
1230        });
1231
1232        self.pending_completions.push(PendingCompletion {
1233            id: pending_completion_id,
1234            _task: task,
1235        });
1236    }
1237
1238    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1239        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1240            return;
1241        };
1242
1243        if !model.provider.is_authenticated(cx) {
1244            return;
1245        }
1246
1247        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1248        request.messages.push(LanguageModelRequestMessage {
1249            role: Role::User,
1250            content: vec![
1251                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1252                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1253                 If the conversation is about a specific subject, include it in the title. \
1254                 Be descriptive. DO NOT speak in the first person."
1255                    .into(),
1256            ],
1257            cache: false,
1258        });
1259
1260        self.pending_summary = cx.spawn(async move |this, cx| {
1261            async move {
1262                let stream = model.model.stream_completion_text(request, &cx);
1263                let mut messages = stream.await?;
1264
1265                let mut new_summary = String::new();
1266                while let Some(message) = messages.stream.next().await {
1267                    let text = message?;
1268                    let mut lines = text.lines();
1269                    new_summary.extend(lines.next());
1270
1271                    // Stop if the LLM generated multiple lines.
1272                    if lines.next().is_some() {
1273                        break;
1274                    }
1275                }
1276
1277                this.update(cx, |this, cx| {
1278                    if !new_summary.is_empty() {
1279                        this.summary = Some(new_summary.into());
1280                    }
1281
1282                    cx.emit(ThreadEvent::SummaryGenerated);
1283                })?;
1284
1285                anyhow::Ok(())
1286            }
1287            .log_err()
1288            .await
1289        });
1290    }
1291
1292    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1293        let last_message_id = self.messages.last().map(|message| message.id)?;
1294
1295        match &self.detailed_summary_state {
1296            DetailedSummaryState::Generating { message_id, .. }
1297            | DetailedSummaryState::Generated { message_id, .. }
1298                if *message_id == last_message_id =>
1299            {
1300                // Already up-to-date
1301                return None;
1302            }
1303            _ => {}
1304        }
1305
1306        let ConfiguredModel { model, provider } =
1307            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1308
1309        if !provider.is_authenticated(cx) {
1310            return None;
1311        }
1312
1313        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1314
1315        request.messages.push(LanguageModelRequestMessage {
1316            role: Role::User,
1317            content: vec![
1318                "Generate a detailed summary of this conversation. Include:\n\
1319                1. A brief overview of what was discussed\n\
1320                2. Key facts or information discovered\n\
1321                3. Outcomes or conclusions reached\n\
1322                4. Any action items or next steps if any\n\
1323                Format it in Markdown with headings and bullet points."
1324                    .into(),
1325            ],
1326            cache: false,
1327        });
1328
1329        let task = cx.spawn(async move |thread, cx| {
1330            let stream = model.stream_completion_text(request, &cx);
1331            let Some(mut messages) = stream.await.log_err() else {
1332                thread
1333                    .update(cx, |this, _cx| {
1334                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1335                    })
1336                    .log_err();
1337
1338                return;
1339            };
1340
1341            let mut new_detailed_summary = String::new();
1342
1343            while let Some(chunk) = messages.stream.next().await {
1344                if let Some(chunk) = chunk.log_err() {
1345                    new_detailed_summary.push_str(&chunk);
1346                }
1347            }
1348
1349            thread
1350                .update(cx, |this, _cx| {
1351                    this.detailed_summary_state = DetailedSummaryState::Generated {
1352                        text: new_detailed_summary.into(),
1353                        message_id: last_message_id,
1354                    };
1355                })
1356                .log_err();
1357        });
1358
1359        self.detailed_summary_state = DetailedSummaryState::Generating {
1360            message_id: last_message_id,
1361        };
1362
1363        Some(task)
1364    }
1365
1366    pub fn is_generating_detailed_summary(&self) -> bool {
1367        matches!(
1368            self.detailed_summary_state,
1369            DetailedSummaryState::Generating { .. }
1370        )
1371    }
1372
1373    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1374        let request = self.to_completion_request(RequestKind::Chat, cx);
1375        let messages = Arc::new(request.messages);
1376        let pending_tool_uses = self
1377            .tool_use
1378            .pending_tool_uses()
1379            .into_iter()
1380            .filter(|tool_use| tool_use.status.is_idle())
1381            .cloned()
1382            .collect::<Vec<_>>();
1383
1384        for tool_use in pending_tool_uses.iter() {
1385            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1386                if tool.needs_confirmation(&tool_use.input, cx)
1387                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1388                {
1389                    self.tool_use.confirm_tool_use(
1390                        tool_use.id.clone(),
1391                        tool_use.ui_text.clone(),
1392                        tool_use.input.clone(),
1393                        messages.clone(),
1394                        tool,
1395                    );
1396                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1397                } else {
1398                    self.run_tool(
1399                        tool_use.id.clone(),
1400                        tool_use.ui_text.clone(),
1401                        tool_use.input.clone(),
1402                        &messages,
1403                        tool,
1404                        cx,
1405                    );
1406                }
1407            }
1408        }
1409
1410        pending_tool_uses
1411    }
1412
1413    pub fn run_tool(
1414        &mut self,
1415        tool_use_id: LanguageModelToolUseId,
1416        ui_text: impl Into<SharedString>,
1417        input: serde_json::Value,
1418        messages: &[LanguageModelRequestMessage],
1419        tool: Arc<dyn Tool>,
1420        cx: &mut Context<Thread>,
1421    ) {
1422        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1423        self.tool_use
1424            .run_pending_tool(tool_use_id, ui_text.into(), task);
1425    }
1426
1427    fn spawn_tool_use(
1428        &mut self,
1429        tool_use_id: LanguageModelToolUseId,
1430        messages: &[LanguageModelRequestMessage],
1431        input: serde_json::Value,
1432        tool: Arc<dyn Tool>,
1433        cx: &mut Context<Thread>,
1434    ) -> Task<()> {
1435        let tool_name: Arc<str> = tool.name().into();
1436
1437        let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1438            Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1439        } else {
1440            tool.run(
1441                input,
1442                messages,
1443                self.project.clone(),
1444                self.action_log.clone(),
1445                cx,
1446            )
1447        };
1448
1449        cx.spawn({
1450            async move |thread: WeakEntity<Thread>, cx| {
1451                let output = run_tool.await;
1452
1453                thread
1454                    .update(cx, |thread, cx| {
1455                        let pending_tool_use = thread.tool_use.insert_tool_output(
1456                            tool_use_id.clone(),
1457                            tool_name,
1458                            output,
1459                            cx,
1460                        );
1461                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1462                    })
1463                    .ok();
1464            }
1465        })
1466    }
1467
1468    fn tool_finished(
1469        &mut self,
1470        tool_use_id: LanguageModelToolUseId,
1471        pending_tool_use: Option<PendingToolUse>,
1472        canceled: bool,
1473        cx: &mut Context<Self>,
1474    ) {
1475        if self.all_tools_finished() {
1476            let model_registry = LanguageModelRegistry::read_global(cx);
1477            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1478                self.attach_tool_results(cx);
1479                if !canceled {
1480                    self.send_to_model(model, RequestKind::Chat, cx);
1481                }
1482            }
1483        }
1484
1485        cx.emit(ThreadEvent::ToolFinished {
1486            tool_use_id,
1487            pending_tool_use,
1488        });
1489    }
1490
1491    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1492        // Insert a user message to contain the tool results.
1493        self.insert_user_message(
1494            // TODO: Sending up a user message without any content results in the model sending back
1495            // responses that also don't have any content. We currently don't handle this case well,
1496            // so for now we provide some text to keep the model on track.
1497            "Here are the tool results.",
1498            Vec::new(),
1499            None,
1500            cx,
1501        );
1502    }
1503
1504    /// Cancels the last pending completion, if there are any pending.
1505    ///
1506    /// Returns whether a completion was canceled.
1507    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1508        let canceled = if self.pending_completions.pop().is_some() {
1509            true
1510        } else {
1511            let mut canceled = false;
1512            for pending_tool_use in self.tool_use.cancel_pending() {
1513                canceled = true;
1514                self.tool_finished(
1515                    pending_tool_use.id.clone(),
1516                    Some(pending_tool_use),
1517                    true,
1518                    cx,
1519                );
1520            }
1521            canceled
1522        };
1523        self.finalize_pending_checkpoint(cx);
1524        canceled
1525    }
1526
1527    pub fn feedback(&self) -> Option<ThreadFeedback> {
1528        self.feedback
1529    }
1530
1531    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1532        self.message_feedback.get(&message_id).copied()
1533    }
1534
1535    pub fn report_message_feedback(
1536        &mut self,
1537        message_id: MessageId,
1538        feedback: ThreadFeedback,
1539        cx: &mut Context<Self>,
1540    ) -> Task<Result<()>> {
1541        if self.message_feedback.get(&message_id) == Some(&feedback) {
1542            return Task::ready(Ok(()));
1543        }
1544
1545        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1546        let serialized_thread = self.serialize(cx);
1547        let thread_id = self.id().clone();
1548        let client = self.project.read(cx).client();
1549
1550        let enabled_tool_names: Vec<String> = self
1551            .tools()
1552            .enabled_tools(cx)
1553            .iter()
1554            .map(|tool| tool.name().to_string())
1555            .collect();
1556
1557        self.message_feedback.insert(message_id, feedback);
1558
1559        cx.notify();
1560
1561        let message_content = self
1562            .message(message_id)
1563            .map(|msg| msg.to_string())
1564            .unwrap_or_default();
1565
1566        cx.background_spawn(async move {
1567            let final_project_snapshot = final_project_snapshot.await;
1568            let serialized_thread = serialized_thread.await?;
1569            let thread_data =
1570                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1571
1572            let rating = match feedback {
1573                ThreadFeedback::Positive => "positive",
1574                ThreadFeedback::Negative => "negative",
1575            };
1576            telemetry::event!(
1577                "Assistant Thread Rated",
1578                rating,
1579                thread_id,
1580                enabled_tool_names,
1581                message_id = message_id.0,
1582                message_content,
1583                thread_data,
1584                final_project_snapshot
1585            );
1586            client.telemetry().flush_events();
1587
1588            Ok(())
1589        })
1590    }
1591
1592    pub fn report_feedback(
1593        &mut self,
1594        feedback: ThreadFeedback,
1595        cx: &mut Context<Self>,
1596    ) -> Task<Result<()>> {
1597        let last_assistant_message_id = self
1598            .messages
1599            .iter()
1600            .rev()
1601            .find(|msg| msg.role == Role::Assistant)
1602            .map(|msg| msg.id);
1603
1604        if let Some(message_id) = last_assistant_message_id {
1605            self.report_message_feedback(message_id, feedback, cx)
1606        } else {
1607            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1608            let serialized_thread = self.serialize(cx);
1609            let thread_id = self.id().clone();
1610            let client = self.project.read(cx).client();
1611            self.feedback = Some(feedback);
1612            cx.notify();
1613
1614            cx.background_spawn(async move {
1615                let final_project_snapshot = final_project_snapshot.await;
1616                let serialized_thread = serialized_thread.await?;
1617                let thread_data = serde_json::to_value(serialized_thread)
1618                    .unwrap_or_else(|_| serde_json::Value::Null);
1619
1620                let rating = match feedback {
1621                    ThreadFeedback::Positive => "positive",
1622                    ThreadFeedback::Negative => "negative",
1623                };
1624                telemetry::event!(
1625                    "Assistant Thread Rated",
1626                    rating,
1627                    thread_id,
1628                    thread_data,
1629                    final_project_snapshot
1630                );
1631                client.telemetry().flush_events();
1632
1633                Ok(())
1634            })
1635        }
1636    }
1637
1638    /// Create a snapshot of the current project state including git information and unsaved buffers.
1639    fn project_snapshot(
1640        project: Entity<Project>,
1641        cx: &mut Context<Self>,
1642    ) -> Task<Arc<ProjectSnapshot>> {
1643        let git_store = project.read(cx).git_store().clone();
1644        let worktree_snapshots: Vec<_> = project
1645            .read(cx)
1646            .visible_worktrees(cx)
1647            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1648            .collect();
1649
1650        cx.spawn(async move |_, cx| {
1651            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1652
1653            let mut unsaved_buffers = Vec::new();
1654            cx.update(|app_cx| {
1655                let buffer_store = project.read(app_cx).buffer_store();
1656                for buffer_handle in buffer_store.read(app_cx).buffers() {
1657                    let buffer = buffer_handle.read(app_cx);
1658                    if buffer.is_dirty() {
1659                        if let Some(file) = buffer.file() {
1660                            let path = file.path().to_string_lossy().to_string();
1661                            unsaved_buffers.push(path);
1662                        }
1663                    }
1664                }
1665            })
1666            .ok();
1667
1668            Arc::new(ProjectSnapshot {
1669                worktree_snapshots,
1670                unsaved_buffer_paths: unsaved_buffers,
1671                timestamp: Utc::now(),
1672            })
1673        })
1674    }
1675
1676    fn worktree_snapshot(
1677        worktree: Entity<project::Worktree>,
1678        git_store: Entity<GitStore>,
1679        cx: &App,
1680    ) -> Task<WorktreeSnapshot> {
1681        cx.spawn(async move |cx| {
1682            // Get worktree path and snapshot
1683            let worktree_info = cx.update(|app_cx| {
1684                let worktree = worktree.read(app_cx);
1685                let path = worktree.abs_path().to_string_lossy().to_string();
1686                let snapshot = worktree.snapshot();
1687                (path, snapshot)
1688            });
1689
1690            let Ok((worktree_path, _snapshot)) = worktree_info else {
1691                return WorktreeSnapshot {
1692                    worktree_path: String::new(),
1693                    git_state: None,
1694                };
1695            };
1696
1697            let git_state = git_store
1698                .update(cx, |git_store, cx| {
1699                    git_store
1700                        .repositories()
1701                        .values()
1702                        .find(|repo| {
1703                            repo.read(cx)
1704                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1705                                .is_some()
1706                        })
1707                        .cloned()
1708                })
1709                .ok()
1710                .flatten()
1711                .map(|repo| {
1712                    repo.update(cx, |repo, _| {
1713                        let current_branch =
1714                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1715                        repo.send_job(None, |state, _| async move {
1716                            let RepositoryState::Local { backend, .. } = state else {
1717                                return GitState {
1718                                    remote_url: None,
1719                                    head_sha: None,
1720                                    current_branch,
1721                                    diff: None,
1722                                };
1723                            };
1724
1725                            let remote_url = backend.remote_url("origin");
1726                            let head_sha = backend.head_sha();
1727                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1728
1729                            GitState {
1730                                remote_url,
1731                                head_sha,
1732                                current_branch,
1733                                diff,
1734                            }
1735                        })
1736                    })
1737                });
1738
1739            let git_state = match git_state {
1740                Some(git_state) => match git_state.ok() {
1741                    Some(git_state) => git_state.await.ok(),
1742                    None => None,
1743                },
1744                None => None,
1745            };
1746
1747            WorktreeSnapshot {
1748                worktree_path,
1749                git_state,
1750            }
1751        })
1752    }
1753
1754    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1755        let mut markdown = Vec::new();
1756
1757        if let Some(summary) = self.summary() {
1758            writeln!(markdown, "# {summary}\n")?;
1759        };
1760
1761        for message in self.messages() {
1762            writeln!(
1763                markdown,
1764                "## {role}\n",
1765                role = match message.role {
1766                    Role::User => "User",
1767                    Role::Assistant => "Assistant",
1768                    Role::System => "System",
1769                }
1770            )?;
1771
1772            if !message.context.is_empty() {
1773                writeln!(markdown, "{}", message.context)?;
1774            }
1775
1776            for segment in &message.segments {
1777                match segment {
1778                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1779                    MessageSegment::Thinking(text) => {
1780                        writeln!(markdown, "<think>{}</think>\n", text)?
1781                    }
1782                }
1783            }
1784
1785            for tool_use in self.tool_uses_for_message(message.id, cx) {
1786                writeln!(
1787                    markdown,
1788                    "**Use Tool: {} ({})**",
1789                    tool_use.name, tool_use.id
1790                )?;
1791                writeln!(markdown, "```json")?;
1792                writeln!(
1793                    markdown,
1794                    "{}",
1795                    serde_json::to_string_pretty(&tool_use.input)?
1796                )?;
1797                writeln!(markdown, "```")?;
1798            }
1799
1800            for tool_result in self.tool_results_for_message(message.id) {
1801                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1802                if tool_result.is_error {
1803                    write!(markdown, " (Error)")?;
1804                }
1805
1806                writeln!(markdown, "**\n")?;
1807                writeln!(markdown, "{}", tool_result.content)?;
1808            }
1809        }
1810
1811        Ok(String::from_utf8_lossy(&markdown).to_string())
1812    }
1813
1814    pub fn keep_edits_in_range(
1815        &mut self,
1816        buffer: Entity<language::Buffer>,
1817        buffer_range: Range<language::Anchor>,
1818        cx: &mut Context<Self>,
1819    ) {
1820        self.action_log.update(cx, |action_log, cx| {
1821            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1822        });
1823    }
1824
1825    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1826        self.action_log
1827            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1828    }
1829
1830    pub fn reject_edits_in_range(
1831        &mut self,
1832        buffer: Entity<language::Buffer>,
1833        buffer_range: Range<language::Anchor>,
1834        cx: &mut Context<Self>,
1835    ) -> Task<Result<()>> {
1836        self.action_log.update(cx, |action_log, cx| {
1837            action_log.reject_edits_in_range(buffer, buffer_range, cx)
1838        })
1839    }
1840
1841    pub fn action_log(&self) -> &Entity<ActionLog> {
1842        &self.action_log
1843    }
1844
1845    pub fn project(&self) -> &Entity<Project> {
1846        &self.project
1847    }
1848
1849    pub fn cumulative_token_usage(&self) -> TokenUsage {
1850        self.cumulative_token_usage.clone()
1851    }
1852
1853    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1854        let model_registry = LanguageModelRegistry::read_global(cx);
1855        let Some(model) = model_registry.default_model() else {
1856            return TotalTokenUsage::default();
1857        };
1858
1859        let max = model.model.max_token_count();
1860
1861        #[cfg(debug_assertions)]
1862        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1863            .unwrap_or("0.8".to_string())
1864            .parse()
1865            .unwrap();
1866        #[cfg(not(debug_assertions))]
1867        let warning_threshold: f32 = 0.8;
1868
1869        let total = self.cumulative_token_usage.total_tokens() as usize;
1870
1871        let ratio = if total >= max {
1872            TokenUsageRatio::Exceeded
1873        } else if total as f32 / max as f32 >= warning_threshold {
1874            TokenUsageRatio::Warning
1875        } else {
1876            TokenUsageRatio::Normal
1877        };
1878
1879        TotalTokenUsage { total, max, ratio }
1880    }
1881
1882    pub fn deny_tool_use(
1883        &mut self,
1884        tool_use_id: LanguageModelToolUseId,
1885        tool_name: Arc<str>,
1886        cx: &mut Context<Self>,
1887    ) {
1888        let err = Err(anyhow::anyhow!(
1889            "Permission to run tool action denied by user"
1890        ));
1891
1892        self.tool_use
1893            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1894        self.tool_finished(tool_use_id.clone(), None, true, cx);
1895    }
1896}
1897
1898#[derive(Debug, Clone)]
1899pub enum ThreadError {
1900    PaymentRequired,
1901    MaxMonthlySpendReached,
1902    Message {
1903        header: SharedString,
1904        message: SharedString,
1905    },
1906}
1907
1908#[derive(Debug, Clone)]
1909pub enum ThreadEvent {
1910    ShowError(ThreadError),
1911    StreamedCompletion,
1912    StreamedAssistantText(MessageId, String),
1913    StreamedAssistantThinking(MessageId, String),
1914    DoneStreaming,
1915    MessageAdded(MessageId),
1916    MessageEdited(MessageId),
1917    MessageDeleted(MessageId),
1918    SummaryGenerated,
1919    SummaryChanged,
1920    UsePendingTools {
1921        tool_uses: Vec<PendingToolUse>,
1922    },
1923    ToolFinished {
1924        #[allow(unused)]
1925        tool_use_id: LanguageModelToolUseId,
1926        /// The pending tool use that corresponds to this tool.
1927        pending_tool_use: Option<PendingToolUse>,
1928    },
1929    CheckpointChanged,
1930    ToolConfirmationNeeded,
1931}
1932
1933impl EventEmitter<ThreadEvent> for Thread {}
1934
/// Bookkeeping for a completion request (presumably one that is still running).
// NOTE(review): `id` presumably distinguishes concurrent completions — confirm against the
// code that creates and removes `PendingCompletion`s.
struct PendingCompletion {
    id: usize,
    // Held but never read. gpui cancels a `Task` when it is dropped, so dropping this struct
    // presumably cancels the completion.
    _task: Task<()>,
}
1939
#[cfg(test)]
mod tests {
    // Tests for how user-message context is stored and how completion requests are built.
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// A user message with an attached file stores the formatted `<context>` block separately
    /// from its text. The completion request prepends that context to the text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 1);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[0].string_contents(), expected_full_message);
    }

    /// Context already attached to an earlier message is not repeated in later messages,
    /// either in the stored message context or in the completion request.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 3);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[0].string_contents().contains("file1.rs"));
        assert!(!request.messages[0].string_contents().contains("file2.rs"));
        assert!(!request.messages[0].string_contents().contains("file3.rs"));

        assert!(!request.messages[1].string_contents().contains("file1.rs"));
        assert!(request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(!request.messages[2].string_contents().contains("file2.rs"));
        assert!(request.messages[2].string_contents().contains("file3.rs"));
    }

    /// Messages sent without attached context have an empty context string and appear
    /// verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 1);
        assert_eq!(
            request.messages[0].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[0].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[1].string_contents(),
            "Are there any good books?"
        );
    }

    /// Editing a buffer after it was attached as context makes the next completion request
    /// end with a user message listing the changed file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at byte offset 1 (inside the first line). The exact position doesn't
            // matter; the buffer only needs to change after it was added as context.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Installs a test `SettingsStore` and initializes the settings these tests depend on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    /// Creates a project over a `FakeFs` whose single root, `/test`, is populated from `files`.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Opens a workspace window for `project` and returns the workspace. It also returns a new
    /// `ThreadStore`, a thread created from that store, and a new `ContextStore` for `project`.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx.update(|_, cx| {
            ThreadStore::new(
                project.clone(),
                Arc::default(),
                Arc::new(PromptBuilder::new(None).unwrap()),
                cx,
            )
            .unwrap()
        });

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    /// Opens the buffer at `path` (e.g. `"test/code.rs"`), adds it to `context_store`, and
    /// returns the buffer.
    ///
    /// Panics if the path cannot be resolved or the buffer cannot be opened. Only an error
    /// from adding the buffer to the context store is returned.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}