thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  19    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  20    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  21    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  22    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
  23    TokenUsage,
  24};
  25use project::Project;
  26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  27use prompt_store::PromptBuilder;
  28use proto::Plan;
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::Settings;
  32use thiserror::Error;
  33use util::{ResultExt as _, TryFutureExt as _, post_inc};
  34use uuid::Uuid;
  35
  36use crate::context::{AssistantContext, ContextId, format_context_as_string};
  37use crate::thread_store::{
  38    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  39    SerializedToolUse, SharedProjectContext,
  40};
  41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  42
/// The kind of completion request being built for the language model.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A conversational request (presumably the normal chat flow — confirm against callers).
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
  49
/// Identifier of a [`Thread`], stored as a reference-counted string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
  54
  55impl ThreadId {
  56    pub fn new() -> Self {
  57        Self(Uuid::new_v4().to_string().into())
  58    }
  59}
  60
  61impl std::fmt::Display for ThreadId {
  62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  63        write!(f, "{}", self.0)
  64    }
  65}
  66
  67impl From<&str> for ThreadId {
  68    fn from(value: &str) -> Self {
  69        Self(value.into())
  70    }
  71}
  72
/// Identifier of a [`Message`] within a [`Thread`]; a plain counter value.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
  75
  76impl MessageId {
  77    fn post_inc(&mut self) -> Self {
  78        Self(post_inc(&mut self.0))
  79    }
  80}
  81
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier of this message within its thread.
    pub id: MessageId,
    /// Who authored the message (user, assistant, or system).
    pub role: Role,
    /// The message body, split into plain-text and thinking segments.
    pub segments: Vec<MessageSegment>,
    /// Formatted context attached to the message; empty when there is none.
    pub context: String,
}
  90
  91impl Message {
  92    /// Returns whether the message contains any meaningful text that should be displayed
  93    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  94    pub fn should_display_content(&self) -> bool {
  95        self.segments.iter().all(|segment| segment.should_display())
  96    }
  97
  98    pub fn push_thinking(&mut self, text: &str) {
  99        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
 100            segment.push_str(text);
 101        } else {
 102            self.segments
 103                .push(MessageSegment::Thinking(text.to_string()));
 104        }
 105    }
 106
 107    pub fn push_text(&mut self, text: &str) {
 108        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 109            segment.push_str(text);
 110        } else {
 111            self.segments.push(MessageSegment::Text(text.to_string()));
 112        }
 113    }
 114
 115    pub fn to_string(&self) -> String {
 116        let mut result = String::new();
 117
 118        if !self.context.is_empty() {
 119            result.push_str(&self.context);
 120        }
 121
 122        for segment in &self.segments {
 123            match segment {
 124                MessageSegment::Text(text) => result.push_str(text),
 125                MessageSegment::Thinking(text) => {
 126                    result.push_str("<think>");
 127                    result.push_str(text);
 128                    result.push_str("</think>");
 129                }
 130            }
 131        }
 132
 133        result
 134    }
 135}
 136
/// One contiguous piece of a [`Message`] body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Regular message text.
    Text(String),
    /// Thinking text; rendered inside `<think>` tags when the message is
    /// converted to a string.
    Thinking(String),
}
 142
 143impl MessageSegment {
 144    pub fn text_mut(&mut self) -> &mut String {
 145        match self {
 146            Self::Text(text) => text,
 147            Self::Thinking(text) => text,
 148        }
 149    }
 150
 151    pub fn should_display(&self) -> bool {
 152        // We add USING_TOOL_MARKER when making a request that includes tool uses
 153        // without non-whitespace text around them, and this can cause the model
 154        // to mimic the pattern, so we consider those segments not displayable.
 155        match self {
 156            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 157            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 158        }
 159    }
 160}
 161
/// A serializable snapshot of project state captured for a thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes (presumably at capture time — confirm).
    pub unsaved_buffer_paths: Vec<String>,
    /// Timestamp associated with the snapshot (UTC).
    pub timestamp: DateTime<Utc>,
}
 168
/// Snapshot of a single worktree within a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree.
    pub worktree_path: String,
    /// Git information for the worktree, when available.
    pub git_state: Option<GitState>,
}
 174
/// Git details recorded for a worktree; every field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the repository's remote.
    pub remote_url: Option<String>,
    /// SHA of the commit at `HEAD`.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch.
    pub current_branch: Option<String>,
    /// Textual diff (which diff is captured is decided by the producer — not visible here).
    pub diff: Option<String>,
}
 182
/// Pairs a git checkpoint with the message it was taken for.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// The message this checkpoint belongs to.
    message_id: MessageId,
    /// The git-store checkpoint that can be restored.
    git_checkpoint: GitStoreCheckpoint,
}
 188
/// Feedback recorded for a thread or for an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 194
/// State of the most recent checkpoint restore that is still running or has failed.
pub enum LastRestoreCheckpoint {
    /// The restore has been started and has not finished yet.
    Pending {
        /// Message whose checkpoint is being restored.
        message_id: MessageId,
    },
    /// The restore failed.
    Error {
        /// Message whose checkpoint failed to restore.
        message_id: MessageId,
        /// String form of the failure.
        error: String,
    },
}
 204
 205impl LastRestoreCheckpoint {
 206    pub fn message_id(&self) -> MessageId {
 207        match self {
 208            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 209            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 210        }
 211    }
 212}
 213
/// Progress of a thread's detailed summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary exists (the default).
    #[default]
    NotGenerated,
    /// A detailed summary is being produced.
    Generating {
        // presumably the last message covered by the summary — TODO confirm
        message_id: MessageId,
    },
    /// A detailed summary is available.
    Generated {
        /// The summary text.
        text: SharedString,
        // presumably the last message covered by the summary — TODO confirm
        message_id: MessageId,
    },
}
 226
 227#[derive(Default)]
 228pub struct TotalTokenUsage {
 229    pub total: usize,
 230    pub max: usize,
 231}
 232
 233impl TotalTokenUsage {
 234    pub fn ratio(&self) -> TokenUsageRatio {
 235        #[cfg(debug_assertions)]
 236        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 237            .unwrap_or("0.8".to_string())
 238            .parse()
 239            .unwrap();
 240        #[cfg(not(debug_assertions))]
 241        let warning_threshold: f32 = 0.8;
 242
 243        if self.total >= self.max {
 244            TokenUsageRatio::Exceeded
 245        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 246            TokenUsageRatio::Warning
 247        } else {
 248            TokenUsageRatio::Normal
 249        }
 250    }
 251
 252    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 253        TotalTokenUsage {
 254            total: self.total + tokens,
 255            max: self.max,
 256        }
 257    }
 258}
 259
 260#[derive(Debug, Default, PartialEq, Eq)]
 261pub enum TokenUsageRatio {
 262    #[default]
 263    Normal,
 264    Warning,
 265    Exceeded,
 266}
 267
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Time of the last modification; refreshed by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// `None` until a summary has been generated.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    /// Messages in conversation order.
    messages: Vec<Message>,
    /// Id that will be given to the next inserted message.
    next_message_id: MessageId,
    /// Every piece of context attached so far, keyed by id; used to avoid
    /// attaching the same context twice.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Ids of the context newly attached with each user message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    /// Restorable checkpoints, keyed by the message they were taken for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// Completions currently in flight; non-empty means the thread is generating.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Set while a checkpoint restore is pending or after it failed.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint supplied with the latest user message, kept until
    /// `finalize_pending_checkpoint` decides whether to store it.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot task started when the thread is created, or the
    /// stored snapshot when the thread is deserialized.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    /// Feedback on the thread as a whole.
    feedback: Option<ThreadFeedback>,
    /// Feedback on individual messages.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
}
 298
/// Details recorded when a message exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 306
 307impl Thread {
 308    pub fn new(
 309        project: Entity<Project>,
 310        tools: Entity<ToolWorkingSet>,
 311        prompt_builder: Arc<PromptBuilder>,
 312        system_prompt: SharedProjectContext,
 313        cx: &mut Context<Self>,
 314    ) -> Self {
 315        Self {
 316            id: ThreadId::new(),
 317            updated_at: Utc::now(),
 318            summary: None,
 319            pending_summary: Task::ready(None),
 320            detailed_summary_state: DetailedSummaryState::NotGenerated,
 321            messages: Vec::new(),
 322            next_message_id: MessageId(0),
 323            context: BTreeMap::default(),
 324            context_by_message: HashMap::default(),
 325            project_context: system_prompt,
 326            checkpoints_by_message: HashMap::default(),
 327            completion_count: 0,
 328            pending_completions: Vec::new(),
 329            project: project.clone(),
 330            prompt_builder,
 331            tools: tools.clone(),
 332            last_restore_checkpoint: None,
 333            pending_checkpoint: None,
 334            tool_use: ToolUseState::new(tools.clone()),
 335            action_log: cx.new(|_| ActionLog::new(project.clone())),
 336            initial_project_snapshot: {
 337                let project_snapshot = Self::project_snapshot(project, cx);
 338                cx.foreground_executor()
 339                    .spawn(async move { Some(project_snapshot.await) })
 340                    .shared()
 341            },
 342            request_token_usage: Vec::new(),
 343            cumulative_token_usage: TokenUsage::default(),
 344            exceeded_window_error: None,
 345            feedback: None,
 346            message_feedback: HashMap::default(),
 347            last_auto_capture_at: None,
 348        }
 349    }
 350
 351    pub fn deserialize(
 352        id: ThreadId,
 353        serialized: SerializedThread,
 354        project: Entity<Project>,
 355        tools: Entity<ToolWorkingSet>,
 356        prompt_builder: Arc<PromptBuilder>,
 357        project_context: SharedProjectContext,
 358        cx: &mut Context<Self>,
 359    ) -> Self {
 360        let next_message_id = MessageId(
 361            serialized
 362                .messages
 363                .last()
 364                .map(|message| message.id.0 + 1)
 365                .unwrap_or(0),
 366        );
 367        let tool_use =
 368            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 369
 370        Self {
 371            id,
 372            updated_at: serialized.updated_at,
 373            summary: Some(serialized.summary),
 374            pending_summary: Task::ready(None),
 375            detailed_summary_state: serialized.detailed_summary_state,
 376            messages: serialized
 377                .messages
 378                .into_iter()
 379                .map(|message| Message {
 380                    id: message.id,
 381                    role: message.role,
 382                    segments: message
 383                        .segments
 384                        .into_iter()
 385                        .map(|segment| match segment {
 386                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 387                            SerializedMessageSegment::Thinking { text } => {
 388                                MessageSegment::Thinking(text)
 389                            }
 390                        })
 391                        .collect(),
 392                    context: message.context,
 393                })
 394                .collect(),
 395            next_message_id,
 396            context: BTreeMap::default(),
 397            context_by_message: HashMap::default(),
 398            project_context,
 399            checkpoints_by_message: HashMap::default(),
 400            completion_count: 0,
 401            pending_completions: Vec::new(),
 402            last_restore_checkpoint: None,
 403            pending_checkpoint: None,
 404            project: project.clone(),
 405            prompt_builder,
 406            tools,
 407            tool_use,
 408            action_log: cx.new(|_| ActionLog::new(project)),
 409            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 410            request_token_usage: serialized.request_token_usage,
 411            cumulative_token_usage: serialized.cumulative_token_usage,
 412            exceeded_window_error: None,
 413            feedback: None,
 414            message_feedback: HashMap::default(),
 415            last_auto_capture_at: None,
 416        }
 417    }
 418
    /// Returns this thread's identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 422
    /// Returns `true` when the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 426
    /// Returns the time the thread was last marked as updated.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 430
    /// Sets the last-updated time to the current UTC time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 434
    /// Returns a clone of the thread's summary, or `None` if it has none yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
 438
    /// Returns a clone of the shared project context held by this thread.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 442
    /// Fallback summary used when the thread has no summary, or when an empty
    /// one is supplied to `set_summary`.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 444
 445    pub fn summary_or_default(&self) -> SharedString {
 446        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 447    }
 448
 449    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 450        let Some(current_summary) = &self.summary else {
 451            // Don't allow setting summary until generated
 452            return;
 453        };
 454
 455        let mut new_summary = new_summary.into();
 456
 457        if new_summary.is_empty() {
 458            new_summary = Self::DEFAULT_SUMMARY;
 459        }
 460
 461        if current_summary != &new_summary {
 462            self.summary = Some(new_summary);
 463            cx.emit(ThreadEvent::SummaryChanged);
 464        }
 465    }
 466
 467    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 468        self.latest_detailed_summary()
 469            .unwrap_or_else(|| self.text().into())
 470    }
 471
 472    fn latest_detailed_summary(&self) -> Option<SharedString> {
 473        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 474            Some(text.clone())
 475        } else {
 476            None
 477        }
 478    }
 479
 480    pub fn message(&self, id: MessageId) -> Option<&Message> {
 481        self.messages.iter().find(|message| message.id == id)
 482    }
 483
    /// Iterates over the thread's messages in order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }
 487
 488    pub fn is_generating(&self) -> bool {
 489        !self.pending_completions.is_empty() || !self.all_tools_finished()
 490    }
 491
    /// Returns the tool working set used by this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
 495
 496    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 497        self.tool_use
 498            .pending_tool_uses()
 499            .into_iter()
 500            .find(|tool_use| &tool_use.id == id)
 501    }
 502
    /// Iterates over the pending tool uses whose status reports that it needs
    /// confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
 509
    /// Returns `true` when at least one tool use is pending, whatever its status.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
 513
    /// Returns a clone of the checkpoint stored for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 517
 518    pub fn restore_checkpoint(
 519        &mut self,
 520        checkpoint: ThreadCheckpoint,
 521        cx: &mut Context<Self>,
 522    ) -> Task<Result<()>> {
 523        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 524            message_id: checkpoint.message_id,
 525        });
 526        cx.emit(ThreadEvent::CheckpointChanged);
 527        cx.notify();
 528
 529        let git_store = self.project().read(cx).git_store().clone();
 530        let restore = git_store.update(cx, |git_store, cx| {
 531            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 532        });
 533
 534        cx.spawn(async move |this, cx| {
 535            let result = restore.await;
 536            this.update(cx, |this, cx| {
 537                if let Err(err) = result.as_ref() {
 538                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 539                        message_id: checkpoint.message_id,
 540                        error: err.to_string(),
 541                    });
 542                } else {
 543                    this.truncate(checkpoint.message_id, cx);
 544                    this.last_restore_checkpoint = None;
 545                }
 546                this.pending_checkpoint = None;
 547                cx.emit(ThreadEvent::CheckpointChanged);
 548                cx.notify();
 549            })?;
 550            result
 551        })
 552    }
 553
 554    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 555        let pending_checkpoint = if self.is_generating() {
 556            return;
 557        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 558            checkpoint
 559        } else {
 560            return;
 561        };
 562
 563        let git_store = self.project.read(cx).git_store().clone();
 564        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 565        cx.spawn(async move |this, cx| match final_checkpoint.await {
 566            Ok(final_checkpoint) => {
 567                let equal = git_store
 568                    .update(cx, |store, cx| {
 569                        store.compare_checkpoints(
 570                            pending_checkpoint.git_checkpoint.clone(),
 571                            final_checkpoint.clone(),
 572                            cx,
 573                        )
 574                    })?
 575                    .await
 576                    .unwrap_or(false);
 577
 578                if equal {
 579                    git_store
 580                        .update(cx, |store, cx| {
 581                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 582                        })?
 583                        .detach();
 584                } else {
 585                    this.update(cx, |this, cx| {
 586                        this.insert_checkpoint(pending_checkpoint, cx)
 587                    })?;
 588                }
 589
 590                git_store
 591                    .update(cx, |store, cx| {
 592                        store.delete_checkpoint(final_checkpoint, cx)
 593                    })?
 594                    .detach();
 595
 596                Ok(())
 597            }
 598            Err(_) => this.update(cx, |this, cx| {
 599                this.insert_checkpoint(pending_checkpoint, cx)
 600            }),
 601        })
 602        .detach();
 603    }
 604
    /// Stores `checkpoint` under its message id (replacing any previous one)
    /// and announces the change.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
 611
    /// Returns the state of the latest checkpoint restore while it is pending
    /// or after it failed; `None` otherwise.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 615
 616    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 617        let Some(message_ix) = self
 618            .messages
 619            .iter()
 620            .rposition(|message| message.id == message_id)
 621        else {
 622            return;
 623        };
 624        for deleted_message in self.messages.drain(message_ix..) {
 625            self.context_by_message.remove(&deleted_message.id);
 626            self.checkpoints_by_message.remove(&deleted_message.id);
 627        }
 628        cx.notify();
 629    }
 630
 631    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 632        self.context_by_message
 633            .get(&id)
 634            .into_iter()
 635            .flat_map(|context| {
 636                context
 637                    .iter()
 638                    .filter_map(|context_id| self.context.get(&context_id))
 639            })
 640    }
 641
 642    /// Returns whether all of the tool uses have finished running.
 643    pub fn all_tools_finished(&self) -> bool {
 644        // If the only pending tool uses left are the ones with errors, then
 645        // that means that we've finished running all of the pending tools.
 646        self.tool_use
 647            .pending_tool_uses()
 648            .iter()
 649            .all(|tool_use| tool_use.status.is_error())
 650    }
 651
    /// Returns the tool uses associated with the given message.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 655
    /// Returns the tool results associated with the given message.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }
 659
    /// Returns the result recorded for the given tool use, if any.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 663
 664    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 665        Some(&self.tool_use.tool_result(id)?.content)
 666    }
 667
    /// Returns a clone of the result card recorded for the given tool use, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 671
    /// Returns whether the tool-use state has tool results for the given message.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
 675
    /// Filter out contexts that have already been included in previous messages
    ///
    /// A context counts as already included when its id is present in the
    /// thread's context map.
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }
 683
    /// Returns `true` when no context with the same id has been attached to
    /// this thread yet.
    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }
 687
 688    pub fn insert_user_message(
 689        &mut self,
 690        text: impl Into<String>,
 691        context: Vec<AssistantContext>,
 692        git_checkpoint: Option<GitStoreCheckpoint>,
 693        cx: &mut Context<Self>,
 694    ) -> MessageId {
 695        let text = text.into();
 696
 697        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 698
 699        let new_context: Vec<_> = context
 700            .into_iter()
 701            .filter(|ctx| self.is_context_new(ctx))
 702            .collect();
 703
 704        if !new_context.is_empty() {
 705            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 706                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 707                    message.context = context_string;
 708                }
 709            }
 710
 711            self.action_log.update(cx, |log, cx| {
 712                // Track all buffers added as context
 713                for ctx in &new_context {
 714                    match ctx {
 715                        AssistantContext::File(file_ctx) => {
 716                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 717                        }
 718                        AssistantContext::Directory(dir_ctx) => {
 719                            for context_buffer in &dir_ctx.context_buffers {
 720                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 721                            }
 722                        }
 723                        AssistantContext::Symbol(symbol_ctx) => {
 724                            log.buffer_added_as_context(
 725                                symbol_ctx.context_symbol.buffer.clone(),
 726                                cx,
 727                            );
 728                        }
 729                        AssistantContext::Excerpt(excerpt_context) => {
 730                            log.buffer_added_as_context(
 731                                excerpt_context.context_buffer.buffer.clone(),
 732                                cx,
 733                            );
 734                        }
 735                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 736                    }
 737                }
 738            });
 739        }
 740
 741        let context_ids = new_context
 742            .iter()
 743            .map(|context| context.id())
 744            .collect::<Vec<_>>();
 745        self.context.extend(
 746            new_context
 747                .into_iter()
 748                .map(|context| (context.id(), context)),
 749        );
 750        self.context_by_message.insert(message_id, context_ids);
 751
 752        if let Some(git_checkpoint) = git_checkpoint {
 753            self.pending_checkpoint = Some(ThreadCheckpoint {
 754                message_id,
 755                git_checkpoint,
 756            });
 757        }
 758
 759        self.auto_capture_telemetry(cx);
 760
 761        message_id
 762    }
 763
 764    pub fn insert_message(
 765        &mut self,
 766        role: Role,
 767        segments: Vec<MessageSegment>,
 768        cx: &mut Context<Self>,
 769    ) -> MessageId {
 770        let id = self.next_message_id.post_inc();
 771        self.messages.push(Message {
 772            id,
 773            role,
 774            segments,
 775            context: String::new(),
 776        });
 777        self.touch_updated_at();
 778        cx.emit(ThreadEvent::MessageAdded(id));
 779        id
 780    }
 781
 782    pub fn edit_message(
 783        &mut self,
 784        id: MessageId,
 785        new_role: Role,
 786        new_segments: Vec<MessageSegment>,
 787        cx: &mut Context<Self>,
 788    ) -> bool {
 789        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 790            return false;
 791        };
 792        message.role = new_role;
 793        message.segments = new_segments;
 794        self.touch_updated_at();
 795        cx.emit(ThreadEvent::MessageEdited(id));
 796        true
 797    }
 798
 799    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 800        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 801            return false;
 802        };
 803        self.messages.remove(index);
 804        self.context_by_message.remove(&id);
 805        self.touch_updated_at();
 806        cx.emit(ThreadEvent::MessageDeleted(id));
 807        true
 808    }
 809
 810    /// Returns the representation of this [`Thread`] in a textual form.
 811    ///
 812    /// This is the representation we use when attaching a thread as context to another thread.
 813    pub fn text(&self) -> String {
 814        let mut text = String::new();
 815
 816        for message in &self.messages {
 817            text.push_str(match message.role {
 818                language_model::Role::User => "User:",
 819                language_model::Role::Assistant => "Assistant:",
 820                language_model::Role::System => "System:",
 821            });
 822            text.push('\n');
 823
 824            for segment in &message.segments {
 825                match segment {
 826                    MessageSegment::Text(content) => text.push_str(content),
 827                    MessageSegment::Thinking(content) => {
 828                        text.push_str(&format!("<think>{}</think>", content))
 829                    }
 830                }
 831            }
 832            text.push('\n');
 833        }
 834
 835        text
 836    }
 837
 838    /// Serializes this thread into a format for storage or telemetry.
 839    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 840        let initial_project_snapshot = self.initial_project_snapshot.clone();
 841        cx.spawn(async move |this, cx| {
 842            let initial_project_snapshot = initial_project_snapshot.await;
 843            this.read_with(cx, |this, cx| SerializedThread {
 844                version: SerializedThread::VERSION.to_string(),
 845                summary: this.summary_or_default(),
 846                updated_at: this.updated_at(),
 847                messages: this
 848                    .messages()
 849                    .map(|message| SerializedMessage {
 850                        id: message.id,
 851                        role: message.role,
 852                        segments: message
 853                            .segments
 854                            .iter()
 855                            .map(|segment| match segment {
 856                                MessageSegment::Text(text) => {
 857                                    SerializedMessageSegment::Text { text: text.clone() }
 858                                }
 859                                MessageSegment::Thinking(text) => {
 860                                    SerializedMessageSegment::Thinking { text: text.clone() }
 861                                }
 862                            })
 863                            .collect(),
 864                        tool_uses: this
 865                            .tool_uses_for_message(message.id, cx)
 866                            .into_iter()
 867                            .map(|tool_use| SerializedToolUse {
 868                                id: tool_use.id,
 869                                name: tool_use.name,
 870                                input: tool_use.input,
 871                            })
 872                            .collect(),
 873                        tool_results: this
 874                            .tool_results_for_message(message.id)
 875                            .into_iter()
 876                            .map(|tool_result| SerializedToolResult {
 877                                tool_use_id: tool_result.tool_use_id.clone(),
 878                                is_error: tool_result.is_error,
 879                                content: tool_result.content.clone(),
 880                            })
 881                            .collect(),
 882                        context: message.context.clone(),
 883                    })
 884                    .collect(),
 885                initial_project_snapshot,
 886                cumulative_token_usage: this.cumulative_token_usage,
 887                request_token_usage: this.request_token_usage.clone(),
 888                detailed_summary_state: this.detailed_summary_state.clone(),
 889                exceeded_window_error: this.exceeded_window_error.clone(),
 890            })
 891        })
 892    }
 893
 894    pub fn send_to_model(
 895        &mut self,
 896        model: Arc<dyn LanguageModel>,
 897        request_kind: RequestKind,
 898        cx: &mut Context<Self>,
 899    ) {
 900        let mut request = self.to_completion_request(request_kind, cx);
 901        if model.supports_tools() {
 902            request.tools = {
 903                let mut tools = Vec::new();
 904                tools.extend(
 905                    self.tools()
 906                        .read(cx)
 907                        .enabled_tools(cx)
 908                        .into_iter()
 909                        .filter_map(|tool| {
 910                            // Skip tools that cannot be supported
 911                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 912                            Some(LanguageModelRequestTool {
 913                                name: tool.name(),
 914                                description: tool.description(),
 915                                input_schema,
 916                            })
 917                        }),
 918                );
 919
 920                tools
 921            };
 922        }
 923
 924        self.stream_completion(request, model, cx);
 925    }
 926
 927    pub fn used_tools_since_last_user_message(&self) -> bool {
 928        for message in self.messages.iter().rev() {
 929            if self.tool_use.message_has_tool_results(message.id) {
 930                return true;
 931            } else if message.role == Role::User {
 932                return false;
 933            }
 934        }
 935
 936        false
 937    }
 938
 939    pub fn to_completion_request(
 940        &self,
 941        request_kind: RequestKind,
 942        cx: &App,
 943    ) -> LanguageModelRequest {
 944        let mut request = LanguageModelRequest {
 945            messages: vec![],
 946            tools: Vec::new(),
 947            stop: Vec::new(),
 948            temperature: None,
 949        };
 950
 951        if let Some(project_context) = self.project_context.borrow().as_ref() {
 952            if let Some(system_prompt) = self
 953                .prompt_builder
 954                .generate_assistant_system_prompt(project_context)
 955                .context("failed to generate assistant system prompt")
 956                .log_err()
 957            {
 958                request.messages.push(LanguageModelRequestMessage {
 959                    role: Role::System,
 960                    content: vec![MessageContent::Text(system_prompt)],
 961                    cache: true,
 962                });
 963            }
 964        } else {
 965            log::error!("project_context not set.")
 966        }
 967
 968        for message in &self.messages {
 969            let mut request_message = LanguageModelRequestMessage {
 970                role: message.role,
 971                content: Vec::new(),
 972                cache: false,
 973            };
 974
 975            match request_kind {
 976                RequestKind::Chat => {
 977                    self.tool_use
 978                        .attach_tool_results(message.id, &mut request_message);
 979                }
 980                RequestKind::Summarize => {
 981                    // We don't care about tool use during summarization.
 982                    if self.tool_use.message_has_tool_results(message.id) {
 983                        continue;
 984                    }
 985                }
 986            }
 987
 988            if !message.segments.is_empty() {
 989                request_message
 990                    .content
 991                    .push(MessageContent::Text(message.to_string()));
 992            }
 993
 994            match request_kind {
 995                RequestKind::Chat => {
 996                    self.tool_use
 997                        .attach_tool_uses(message.id, &mut request_message);
 998                }
 999                RequestKind::Summarize => {
1000                    // We don't care about tool use during summarization.
1001                }
1002            };
1003
1004            request.messages.push(request_message);
1005        }
1006
1007        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1008        if let Some(last) = request.messages.last_mut() {
1009            last.cache = true;
1010        }
1011
1012        self.attached_tracked_files_state(&mut request.messages, cx);
1013
1014        request
1015    }
1016
1017    fn attached_tracked_files_state(
1018        &self,
1019        messages: &mut Vec<LanguageModelRequestMessage>,
1020        cx: &App,
1021    ) {
1022        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1023
1024        let mut stale_message = String::new();
1025
1026        let action_log = self.action_log.read(cx);
1027
1028        for stale_file in action_log.stale_buffers(cx) {
1029            let Some(file) = stale_file.read(cx).file() else {
1030                continue;
1031            };
1032
1033            if stale_message.is_empty() {
1034                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1035            }
1036
1037            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1038        }
1039
1040        let mut content = Vec::with_capacity(2);
1041
1042        if !stale_message.is_empty() {
1043            content.push(stale_message.into());
1044        }
1045
1046        if action_log.has_edited_files_since_project_diagnostics_check() {
1047            content.push(
1048                "\n\nWhen you're done making changes, make sure to check project diagnostics \
1049                and fix all errors AND warnings you introduced! \
1050                DO NOT mention you're going to do this until you're done."
1051                    .into(),
1052            );
1053        }
1054
1055        if !content.is_empty() {
1056            let context_message = LanguageModelRequestMessage {
1057                role: Role::User,
1058                content,
1059                cache: false,
1060            };
1061
1062            messages.push(context_message);
1063        }
1064    }
1065
    /// Streams a completion for `request` from `model` and applies the streamed events to
    /// this thread.
    ///
    /// The work runs in a task spawned on `cx`. The task is stored in
    /// `self.pending_completions` under an id taken from `self.completion_count`.
    ///
    /// While the stream is open, each event mutates the thread (new assistant messages,
    /// appended text/thinking, token usage, tool-use requests) and emits
    /// `ThreadEvent::StreamedCompletion`. When the stream ends successfully the pending
    /// completion is removed and a summary is requested if the thread has none and holds at
    /// least two messages.
    ///
    /// After the stream finishes (successfully or not) the pending checkpoint is finalized,
    /// pending tools are started when the model stopped for tool use, errors are surfaced
    /// as `ThreadEvent::ShowError` (or stored in `exceeded_window_error` for a context-window
    /// overflow) and the last completion is canceled, `ThreadEvent::Stopped` is emitted, and
    /// a telemetry event with the token usage of this completion is recorded.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Remember the usage before this completion so the telemetry below can report
            // only the difference.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;
                // Used if the stream never reports an explicit stop reason.
                let mut stop_reason = StopReason::EndTurn;
                // Most recent usage figure reported for this request.
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    thread
                        .update(cx, |_thread, cx| {
                            cx.emit(ThreadEvent::UsageUpdated(usage));
                        })
                        .ok();
                }

                while let Some(event) = events.next().await {
                    // A stream error aborts the completion and is handled below.
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Swap this request's previously counted usage for the new
                                // figure in the cumulative total.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the latest assistant message,
                                // creating an empty one if the thread has none.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    smol::future::yield_now().await;
                }

                // The stream ended without error: this completion is no longer pending.
                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map known error types to dedicated events; anything else is
                            // shown as a generic message built from the error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        // Stored on the thread rather than emitted as a
                                        // `ShowError` event.
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    thread.auto_capture_telemetry(cx);

                    // Only reported when the initial usage could be read from the thread.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1280
1281    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1282        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1283            return;
1284        };
1285
1286        if !model.provider.is_authenticated(cx) {
1287            return;
1288        }
1289
1290        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1291        request.messages.push(LanguageModelRequestMessage {
1292            role: Role::User,
1293            content: vec![
1294                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1295                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1296                 If the conversation is about a specific subject, include it in the title. \
1297                 Be descriptive. DO NOT speak in the first person."
1298                    .into(),
1299            ],
1300            cache: false,
1301        });
1302
1303        self.pending_summary = cx.spawn(async move |this, cx| {
1304            async move {
1305                let stream = model.model.stream_completion_text(request, &cx);
1306                let mut messages = stream.await?;
1307
1308                let mut new_summary = String::new();
1309                while let Some(message) = messages.stream.next().await {
1310                    let text = message?;
1311                    let mut lines = text.lines();
1312                    new_summary.extend(lines.next());
1313
1314                    // Stop if the LLM generated multiple lines.
1315                    if lines.next().is_some() {
1316                        break;
1317                    }
1318                }
1319
1320                this.update(cx, |this, cx| {
1321                    if !new_summary.is_empty() {
1322                        this.summary = Some(new_summary.into());
1323                    }
1324
1325                    cx.emit(ThreadEvent::SummaryGenerated);
1326                })?;
1327
1328                anyhow::Ok(())
1329            }
1330            .log_err()
1331            .await
1332        });
1333    }
1334
1335    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1336        let last_message_id = self.messages.last().map(|message| message.id)?;
1337
1338        match &self.detailed_summary_state {
1339            DetailedSummaryState::Generating { message_id, .. }
1340            | DetailedSummaryState::Generated { message_id, .. }
1341                if *message_id == last_message_id =>
1342            {
1343                // Already up-to-date
1344                return None;
1345            }
1346            _ => {}
1347        }
1348
1349        let ConfiguredModel { model, provider } =
1350            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1351
1352        if !provider.is_authenticated(cx) {
1353            return None;
1354        }
1355
1356        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1357
1358        request.messages.push(LanguageModelRequestMessage {
1359            role: Role::User,
1360            content: vec![
1361                "Generate a detailed summary of this conversation. Include:\n\
1362                1. A brief overview of what was discussed\n\
1363                2. Key facts or information discovered\n\
1364                3. Outcomes or conclusions reached\n\
1365                4. Any action items or next steps if any\n\
1366                Format it in Markdown with headings and bullet points."
1367                    .into(),
1368            ],
1369            cache: false,
1370        });
1371
1372        let task = cx.spawn(async move |thread, cx| {
1373            let stream = model.stream_completion_text(request, &cx);
1374            let Some(mut messages) = stream.await.log_err() else {
1375                thread
1376                    .update(cx, |this, _cx| {
1377                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1378                    })
1379                    .log_err();
1380
1381                return;
1382            };
1383
1384            let mut new_detailed_summary = String::new();
1385
1386            while let Some(chunk) = messages.stream.next().await {
1387                if let Some(chunk) = chunk.log_err() {
1388                    new_detailed_summary.push_str(&chunk);
1389                }
1390            }
1391
1392            thread
1393                .update(cx, |this, _cx| {
1394                    this.detailed_summary_state = DetailedSummaryState::Generated {
1395                        text: new_detailed_summary.into(),
1396                        message_id: last_message_id,
1397                    };
1398                })
1399                .log_err();
1400        });
1401
1402        self.detailed_summary_state = DetailedSummaryState::Generating {
1403            message_id: last_message_id,
1404        };
1405
1406        Some(task)
1407    }
1408
1409    pub fn is_generating_detailed_summary(&self) -> bool {
1410        matches!(
1411            self.detailed_summary_state,
1412            DetailedSummaryState::Generating { .. }
1413        )
1414    }
1415
1416    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1417        self.auto_capture_telemetry(cx);
1418        let request = self.to_completion_request(RequestKind::Chat, cx);
1419        let messages = Arc::new(request.messages);
1420        let pending_tool_uses = self
1421            .tool_use
1422            .pending_tool_uses()
1423            .into_iter()
1424            .filter(|tool_use| tool_use.status.is_idle())
1425            .cloned()
1426            .collect::<Vec<_>>();
1427
1428        for tool_use in pending_tool_uses.iter() {
1429            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1430                if tool.needs_confirmation(&tool_use.input, cx)
1431                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1432                {
1433                    self.tool_use.confirm_tool_use(
1434                        tool_use.id.clone(),
1435                        tool_use.ui_text.clone(),
1436                        tool_use.input.clone(),
1437                        messages.clone(),
1438                        tool,
1439                    );
1440                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1441                } else {
1442                    self.run_tool(
1443                        tool_use.id.clone(),
1444                        tool_use.ui_text.clone(),
1445                        tool_use.input.clone(),
1446                        &messages,
1447                        tool,
1448                        cx,
1449                    );
1450                }
1451            }
1452        }
1453
1454        pending_tool_uses
1455    }
1456
1457    pub fn run_tool(
1458        &mut self,
1459        tool_use_id: LanguageModelToolUseId,
1460        ui_text: impl Into<SharedString>,
1461        input: serde_json::Value,
1462        messages: &[LanguageModelRequestMessage],
1463        tool: Arc<dyn Tool>,
1464        cx: &mut Context<Thread>,
1465    ) {
1466        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1467        self.tool_use
1468            .run_pending_tool(tool_use_id, ui_text.into(), task);
1469    }
1470
1471    fn spawn_tool_use(
1472        &mut self,
1473        tool_use_id: LanguageModelToolUseId,
1474        messages: &[LanguageModelRequestMessage],
1475        input: serde_json::Value,
1476        tool: Arc<dyn Tool>,
1477        cx: &mut Context<Thread>,
1478    ) -> Task<()> {
1479        let tool_name: Arc<str> = tool.name().into();
1480
1481        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1482            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1483        } else {
1484            tool.run(
1485                input,
1486                messages,
1487                self.project.clone(),
1488                self.action_log.clone(),
1489                cx,
1490            )
1491        };
1492
1493        // Store the card separately if it exists
1494        if let Some(card) = tool_result.card.clone() {
1495            self.tool_use
1496                .insert_tool_result_card(tool_use_id.clone(), card);
1497        }
1498
1499        cx.spawn({
1500            async move |thread: WeakEntity<Thread>, cx| {
1501                let output = tool_result.output.await;
1502
1503                thread
1504                    .update(cx, |thread, cx| {
1505                        let pending_tool_use = thread.tool_use.insert_tool_output(
1506                            tool_use_id.clone(),
1507                            tool_name,
1508                            output,
1509                            cx,
1510                        );
1511                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1512                    })
1513                    .ok();
1514            }
1515        })
1516    }
1517
1518    fn tool_finished(
1519        &mut self,
1520        tool_use_id: LanguageModelToolUseId,
1521        pending_tool_use: Option<PendingToolUse>,
1522        canceled: bool,
1523        cx: &mut Context<Self>,
1524    ) {
1525        if self.all_tools_finished() {
1526            let model_registry = LanguageModelRegistry::read_global(cx);
1527            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1528                self.attach_tool_results(cx);
1529                if !canceled {
1530                    self.send_to_model(model, RequestKind::Chat, cx);
1531                }
1532            }
1533        }
1534
1535        cx.emit(ThreadEvent::ToolFinished {
1536            tool_use_id,
1537            pending_tool_use,
1538        });
1539    }
1540
1541    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1542        // Insert a user message to contain the tool results.
1543        self.insert_user_message(
1544            // TODO: Sending up a user message without any content results in the model sending back
1545            // responses that also don't have any content. We currently don't handle this case well,
1546            // so for now we provide some text to keep the model on track.
1547            "Here are the tool results.",
1548            Vec::new(),
1549            None,
1550            cx,
1551        );
1552    }
1553
1554    /// Cancels the last pending completion, if there are any pending.
1555    ///
1556    /// Returns whether a completion was canceled.
1557    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1558        let canceled = if self.pending_completions.pop().is_some() {
1559            true
1560        } else {
1561            let mut canceled = false;
1562            for pending_tool_use in self.tool_use.cancel_pending() {
1563                canceled = true;
1564                self.tool_finished(
1565                    pending_tool_use.id.clone(),
1566                    Some(pending_tool_use),
1567                    true,
1568                    cx,
1569                );
1570            }
1571            canceled
1572        };
1573        self.finalize_pending_checkpoint(cx);
1574        canceled
1575    }
1576
    /// Returns the value of this thread's `feedback` field, or `None` if it is unset.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1580
1581    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1582        self.message_feedback.get(&message_id).copied()
1583    }
1584
1585    pub fn report_message_feedback(
1586        &mut self,
1587        message_id: MessageId,
1588        feedback: ThreadFeedback,
1589        cx: &mut Context<Self>,
1590    ) -> Task<Result<()>> {
1591        if self.message_feedback.get(&message_id) == Some(&feedback) {
1592            return Task::ready(Ok(()));
1593        }
1594
1595        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1596        let serialized_thread = self.serialize(cx);
1597        let thread_id = self.id().clone();
1598        let client = self.project.read(cx).client();
1599
1600        let enabled_tool_names: Vec<String> = self
1601            .tools()
1602            .read(cx)
1603            .enabled_tools(cx)
1604            .iter()
1605            .map(|tool| tool.name().to_string())
1606            .collect();
1607
1608        self.message_feedback.insert(message_id, feedback);
1609
1610        cx.notify();
1611
1612        let message_content = self
1613            .message(message_id)
1614            .map(|msg| msg.to_string())
1615            .unwrap_or_default();
1616
1617        cx.background_spawn(async move {
1618            let final_project_snapshot = final_project_snapshot.await;
1619            let serialized_thread = serialized_thread.await?;
1620            let thread_data =
1621                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1622
1623            let rating = match feedback {
1624                ThreadFeedback::Positive => "positive",
1625                ThreadFeedback::Negative => "negative",
1626            };
1627            telemetry::event!(
1628                "Assistant Thread Rated",
1629                rating,
1630                thread_id,
1631                enabled_tool_names,
1632                message_id = message_id.0,
1633                message_content,
1634                thread_data,
1635                final_project_snapshot
1636            );
1637            client.telemetry().flush_events();
1638
1639            Ok(())
1640        })
1641    }
1642
1643    pub fn report_feedback(
1644        &mut self,
1645        feedback: ThreadFeedback,
1646        cx: &mut Context<Self>,
1647    ) -> Task<Result<()>> {
1648        let last_assistant_message_id = self
1649            .messages
1650            .iter()
1651            .rev()
1652            .find(|msg| msg.role == Role::Assistant)
1653            .map(|msg| msg.id);
1654
1655        if let Some(message_id) = last_assistant_message_id {
1656            self.report_message_feedback(message_id, feedback, cx)
1657        } else {
1658            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1659            let serialized_thread = self.serialize(cx);
1660            let thread_id = self.id().clone();
1661            let client = self.project.read(cx).client();
1662            self.feedback = Some(feedback);
1663            cx.notify();
1664
1665            cx.background_spawn(async move {
1666                let final_project_snapshot = final_project_snapshot.await;
1667                let serialized_thread = serialized_thread.await?;
1668                let thread_data = serde_json::to_value(serialized_thread)
1669                    .unwrap_or_else(|_| serde_json::Value::Null);
1670
1671                let rating = match feedback {
1672                    ThreadFeedback::Positive => "positive",
1673                    ThreadFeedback::Negative => "negative",
1674                };
1675                telemetry::event!(
1676                    "Assistant Thread Rated",
1677                    rating,
1678                    thread_id,
1679                    thread_data,
1680                    final_project_snapshot
1681                );
1682                client.telemetry().flush_events();
1683
1684                Ok(())
1685            })
1686        }
1687    }
1688
1689    /// Create a snapshot of the current project state including git information and unsaved buffers.
1690    fn project_snapshot(
1691        project: Entity<Project>,
1692        cx: &mut Context<Self>,
1693    ) -> Task<Arc<ProjectSnapshot>> {
1694        let git_store = project.read(cx).git_store().clone();
1695        let worktree_snapshots: Vec<_> = project
1696            .read(cx)
1697            .visible_worktrees(cx)
1698            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1699            .collect();
1700
1701        cx.spawn(async move |_, cx| {
1702            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1703
1704            let mut unsaved_buffers = Vec::new();
1705            cx.update(|app_cx| {
1706                let buffer_store = project.read(app_cx).buffer_store();
1707                for buffer_handle in buffer_store.read(app_cx).buffers() {
1708                    let buffer = buffer_handle.read(app_cx);
1709                    if buffer.is_dirty() {
1710                        if let Some(file) = buffer.file() {
1711                            let path = file.path().to_string_lossy().to_string();
1712                            unsaved_buffers.push(path);
1713                        }
1714                    }
1715                }
1716            })
1717            .ok();
1718
1719            Arc::new(ProjectSnapshot {
1720                worktree_snapshots,
1721                unsaved_buffer_paths: unsaved_buffers,
1722                timestamp: Utc::now(),
1723            })
1724        })
1725    }
1726
1727    fn worktree_snapshot(
1728        worktree: Entity<project::Worktree>,
1729        git_store: Entity<GitStore>,
1730        cx: &App,
1731    ) -> Task<WorktreeSnapshot> {
1732        cx.spawn(async move |cx| {
1733            // Get worktree path and snapshot
1734            let worktree_info = cx.update(|app_cx| {
1735                let worktree = worktree.read(app_cx);
1736                let path = worktree.abs_path().to_string_lossy().to_string();
1737                let snapshot = worktree.snapshot();
1738                (path, snapshot)
1739            });
1740
1741            let Ok((worktree_path, _snapshot)) = worktree_info else {
1742                return WorktreeSnapshot {
1743                    worktree_path: String::new(),
1744                    git_state: None,
1745                };
1746            };
1747
1748            let git_state = git_store
1749                .update(cx, |git_store, cx| {
1750                    git_store
1751                        .repositories()
1752                        .values()
1753                        .find(|repo| {
1754                            repo.read(cx)
1755                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1756                                .is_some()
1757                        })
1758                        .cloned()
1759                })
1760                .ok()
1761                .flatten()
1762                .map(|repo| {
1763                    repo.update(cx, |repo, _| {
1764                        let current_branch =
1765                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1766                        repo.send_job(None, |state, _| async move {
1767                            let RepositoryState::Local { backend, .. } = state else {
1768                                return GitState {
1769                                    remote_url: None,
1770                                    head_sha: None,
1771                                    current_branch,
1772                                    diff: None,
1773                                };
1774                            };
1775
1776                            let remote_url = backend.remote_url("origin");
1777                            let head_sha = backend.head_sha();
1778                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1779
1780                            GitState {
1781                                remote_url,
1782                                head_sha,
1783                                current_branch,
1784                                diff,
1785                            }
1786                        })
1787                    })
1788                });
1789
1790            let git_state = match git_state {
1791                Some(git_state) => match git_state.ok() {
1792                    Some(git_state) => git_state.await.ok(),
1793                    None => None,
1794                },
1795                None => None,
1796            };
1797
1798            WorktreeSnapshot {
1799                worktree_path,
1800                git_state,
1801            }
1802        })
1803    }
1804
1805    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1806        let mut markdown = Vec::new();
1807
1808        if let Some(summary) = self.summary() {
1809            writeln!(markdown, "# {summary}\n")?;
1810        };
1811
1812        for message in self.messages() {
1813            writeln!(
1814                markdown,
1815                "## {role}\n",
1816                role = match message.role {
1817                    Role::User => "User",
1818                    Role::Assistant => "Assistant",
1819                    Role::System => "System",
1820                }
1821            )?;
1822
1823            if !message.context.is_empty() {
1824                writeln!(markdown, "{}", message.context)?;
1825            }
1826
1827            for segment in &message.segments {
1828                match segment {
1829                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1830                    MessageSegment::Thinking(text) => {
1831                        writeln!(markdown, "<think>{}</think>\n", text)?
1832                    }
1833                }
1834            }
1835
1836            for tool_use in self.tool_uses_for_message(message.id, cx) {
1837                writeln!(
1838                    markdown,
1839                    "**Use Tool: {} ({})**",
1840                    tool_use.name, tool_use.id
1841                )?;
1842                writeln!(markdown, "```json")?;
1843                writeln!(
1844                    markdown,
1845                    "{}",
1846                    serde_json::to_string_pretty(&tool_use.input)?
1847                )?;
1848                writeln!(markdown, "```")?;
1849            }
1850
1851            for tool_result in self.tool_results_for_message(message.id) {
1852                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1853                if tool_result.is_error {
1854                    write!(markdown, " (Error)")?;
1855                }
1856
1857                writeln!(markdown, "**\n")?;
1858                writeln!(markdown, "{}", tool_result.content)?;
1859            }
1860        }
1861
1862        Ok(String::from_utf8_lossy(&markdown).to_string())
1863    }
1864
1865    pub fn keep_edits_in_range(
1866        &mut self,
1867        buffer: Entity<language::Buffer>,
1868        buffer_range: Range<language::Anchor>,
1869        cx: &mut Context<Self>,
1870    ) {
1871        self.action_log.update(cx, |action_log, cx| {
1872            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1873        });
1874    }
1875
1876    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1877        self.action_log
1878            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1879    }
1880
1881    pub fn reject_edits_in_ranges(
1882        &mut self,
1883        buffer: Entity<language::Buffer>,
1884        buffer_ranges: Vec<Range<language::Anchor>>,
1885        cx: &mut Context<Self>,
1886    ) -> Task<Result<()>> {
1887        self.action_log.update(cx, |action_log, cx| {
1888            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1889        })
1890    }
1891
    /// Returns the action log entity used by this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1895
    /// Returns the project entity this thread belongs to.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1899
1900    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1901        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1902            return;
1903        }
1904
1905        let now = Instant::now();
1906        if let Some(last) = self.last_auto_capture_at {
1907            if now.duration_since(last).as_secs() < 10 {
1908                return;
1909            }
1910        }
1911
1912        self.last_auto_capture_at = Some(now);
1913
1914        let thread_id = self.id().clone();
1915        let github_login = self
1916            .project
1917            .read(cx)
1918            .user_store()
1919            .read(cx)
1920            .current_user()
1921            .map(|user| user.github_login.clone());
1922        let client = self.project.read(cx).client().clone();
1923        let serialize_task = self.serialize(cx);
1924
1925        cx.background_executor()
1926            .spawn(async move {
1927                if let Ok(serialized_thread) = serialize_task.await {
1928                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1929                        telemetry::event!(
1930                            "Agent Thread Auto-Captured",
1931                            thread_id = thread_id.to_string(),
1932                            thread_data = thread_data,
1933                            auto_capture_reason = "tracked_user",
1934                            github_login = github_login
1935                        );
1936
1937                        client.telemetry().flush_events();
1938                    }
1939                }
1940            })
1941            .detach();
1942    }
1943
    /// Returns the cumulative token usage stored on this thread.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
1947
1948    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
1949        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
1950            return TotalTokenUsage::default();
1951        };
1952
1953        let max = model.model.max_token_count();
1954
1955        let index = self
1956            .messages
1957            .iter()
1958            .position(|msg| msg.id == message_id)
1959            .unwrap_or(0);
1960
1961        if index == 0 {
1962            return TotalTokenUsage { total: 0, max };
1963        }
1964
1965        let token_usage = &self
1966            .request_token_usage
1967            .get(index - 1)
1968            .cloned()
1969            .unwrap_or_default();
1970
1971        TotalTokenUsage {
1972            total: token_usage.total_tokens() as usize,
1973            max,
1974        }
1975    }
1976
1977    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1978        let model_registry = LanguageModelRegistry::read_global(cx);
1979        let Some(model) = model_registry.default_model() else {
1980            return TotalTokenUsage::default();
1981        };
1982
1983        let max = model.model.max_token_count();
1984
1985        if let Some(exceeded_error) = &self.exceeded_window_error {
1986            if model.model.id() == exceeded_error.model_id {
1987                return TotalTokenUsage {
1988                    total: exceeded_error.token_count,
1989                    max,
1990                };
1991            }
1992        }
1993
1994        let total = self
1995            .token_usage_at_last_message()
1996            .unwrap_or_default()
1997            .total_tokens() as usize;
1998
1999        TotalTokenUsage { total, max }
2000    }
2001
2002    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2003        self.request_token_usage
2004            .get(self.messages.len().saturating_sub(1))
2005            .or_else(|| self.request_token_usage.last())
2006            .cloned()
2007    }
2008
2009    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2010        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2011        self.request_token_usage
2012            .resize(self.messages.len(), placeholder);
2013
2014        if let Some(last) = self.request_token_usage.last_mut() {
2015            *last = token_usage;
2016        }
2017    }
2018
2019    pub fn deny_tool_use(
2020        &mut self,
2021        tool_use_id: LanguageModelToolUseId,
2022        tool_name: Arc<str>,
2023        cx: &mut Context<Self>,
2024    ) {
2025        let err = Err(anyhow::anyhow!(
2026            "Permission to run tool action denied by user"
2027        ));
2028
2029        self.tool_use
2030            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2031        self.tool_finished(tool_use_id.clone(), None, true, cx);
2032    }
2033}
2034
/// Errors a thread can surface; carried by [`ThreadEvent::ShowError`].
///
/// The `#[error(...)]` attributes define each variant's display text.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Displayed as "Payment required".
    #[error("Payment required")]
    PaymentRequired,
    /// Displayed as "Max monthly spend reached".
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// Displayed as "Model request limit reached"; carries a `plan`.
    // NOTE(review): presumably the plan whose request limit was hit — confirm at the emit site.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A free-form error, displayed as "Message {header}: {message}".
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2049
/// Events emitted by [`Thread`] (it implements `EventEmitter<ThreadEvent>`).
///
/// The variant notes below describe payloads only; the exact emit points live
/// in the `Thread` implementation.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a [`ThreadError`].
    ShowError(ThreadError),
    /// Carries a `RequestUsage` value.
    UsageUpdated(RequestUsage),
    StreamedCompletion,
    /// Carries a message id and a piece of text.
    // NOTE(review): presumably the assistant message being streamed and the newly streamed chunk — confirm.
    StreamedAssistantText(MessageId, String),
    /// Carries a message id and a piece of text.
    // NOTE(review): presumably the thinking-segment counterpart of `StreamedAssistantText` — confirm.
    StreamedAssistantThinking(MessageId, String),
    /// Carries either a `StopReason` or a shared error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// Carries the id of the affected message.
    MessageAdded(MessageId),
    /// Carries the id of the affected message.
    MessageEdited(MessageId),
    /// Carries the id of the affected message.
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    /// Carries a list of pending tool uses.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// Carries the id of a tool use and, optionally, its pending tool use.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}
2075
// Marks `Thread` as an emitter of `ThreadEvent` values.
impl EventEmitter<ThreadEvent> for Thread {}
2077
/// A completion request that is still in flight.
// NOTE(review): presumably the element type of `Thread::pending_completions`,
// which `cancel_last_completion` pops from — confirm against the struct definition.
struct PendingCompletion {
    /// Identifier of this pending completion.
    id: usize,
    /// Task handle held for as long as the completion is pending; it is never
    /// read (hence the leading underscore).
    // NOTE(review): dropping the entry presumably cancels the task — confirm gpui `Task` drop semantics.
    _task: Task<()>,
}
2082
2083#[cfg(test)]
2084mod tests {
2085    use super::*;
2086    use crate::{ThreadStore, context_store::ContextStore, thread_store};
2087    use assistant_settings::AssistantSettings;
2088    use context_server::ContextServerSettings;
2089    use editor::EditorSettings;
2090    use gpui::TestAppContext;
2091    use project::{FakeFs, Project};
2092    use prompt_store::PromptBuilder;
2093    use serde_json::json;
2094    use settings::{Settings, SettingsStore};
2095    use std::sync::Arc;
2096    use theme::ThemeSettings;
2097    use util::path;
2098    use workspace::Workspace;
2099
    /// Checks that a user message inserted with one file context stores the
    /// rendered context block on the message and that the completion request
    /// sends the context immediately followed by the message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // The expected text is compared exactly, including the leading and
        // trailing newlines of the raw string.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Two messages are expected; the user message is at index 1.
        // NOTE(review): index 0 is presumably the system prompt — confirm in `to_completion_request`.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2167
    /// Checks that each user message only carries contexts that were not
    /// already attached to an earlier message in the same thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain 4 messages: the 3 user messages at
        // indices 1..=3, preceded by one message at index 0.
        // NOTE(review): index 0 is presumably the system prompt — confirm in `to_completion_request`.
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }
2271
    /// Checks that user messages inserted without any context have an empty
    /// context string and appear verbatim in the completion request.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // User messages start at index 1 of the request.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2337
    /// Checks that editing a buffer that was attached as context makes the
    /// next completion request end with a "These files changed since last
    /// read:" user message, and that no such message exists before the edit.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at position 1 of the buffer (the empty range `1..1`).
            // This is not the end of line 1; the exact position is irrelevant,
            // any edit is enough for this test.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2416
    /// Installs a test `SettingsStore` as a global and then runs the settings
    /// registration and init calls listed below.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The store is set as a global before any registration call runs.
            // NOTE(review): the registrations presumably depend on it — confirm before reordering.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }
2431
    // Helper to create a test project with test files.
    //
    // `files` is inserted into a fake file system under `/test`, and the
    // returned project is created with `/test` as its only root path.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }
2441
2442    async fn setup_test_environment(
2443        cx: &mut TestAppContext,
2444        project: Entity<Project>,
2445    ) -> (
2446        Entity<Workspace>,
2447        Entity<ThreadStore>,
2448        Entity<Thread>,
2449        Entity<ContextStore>,
2450    ) {
2451        let (workspace, cx) =
2452            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2453
2454        let thread_store = cx
2455            .update(|_, cx| {
2456                ThreadStore::load(
2457                    project.clone(),
2458                    cx.new(|_| ToolWorkingSet::default()),
2459                    Arc::new(PromptBuilder::new(None).unwrap()),
2460                    cx,
2461                )
2462            })
2463            .await;
2464
2465        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2466        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2467
2468        (workspace, thread_store, thread, context_store)
2469    }
2470
2471    async fn add_file_to_context(
2472        project: &Entity<Project>,
2473        context_store: &Entity<ContextStore>,
2474        path: &str,
2475        cx: &mut TestAppContext,
2476    ) -> Result<Entity<language::Buffer>> {
2477        let buffer_path = project
2478            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2479            .unwrap();
2480
2481        let buffer = project
2482            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2483            .await
2484            .unwrap();
2485
2486        context_store
2487            .update(cx, |store, cx| {
2488                store.add_file_from_buffer(buffer.clone(), cx)
2489            })
2490            .await?;
2491
2492        Ok(buffer)
2493    }
2494}