1use std::fmt::Write as _;
2use std::io::Write;
3use std::ops::Range;
4use std::sync::Arc;
5use std::time::Instant;
6
7use anyhow::{Context as _, Result, anyhow};
8use assistant_settings::AssistantSettings;
9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
10use chrono::{DateTime, Utc};
11use collections::{BTreeMap, HashMap};
12use feature_flags::{self, FeatureFlagAppExt};
13use futures::future::Shared;
14use futures::{FutureExt, StreamExt as _};
15use git::repository::DiffType;
16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
17use language_model::{
18 ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
19 LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
20 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
21 LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
22 ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
23 TokenUsage,
24};
25use project::Project;
26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
27use prompt_store::PromptBuilder;
28use proto::Plan;
29use schemars::JsonSchema;
30use serde::{Deserialize, Serialize};
31use settings::Settings;
32use thiserror::Error;
33use util::{ResultExt as _, TryFutureExt as _, post_inc};
34use uuid::Uuid;
35
36use crate::context::{AssistantContext, ContextId, format_context_as_string};
37use crate::thread_store::{
38 SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
39 SerializedToolUse, SharedProjectContext,
40};
41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
42
43#[derive(Debug, Clone, Copy)]
44pub enum RequestKind {
45 Chat,
46 /// Used when summarizing a thread.
47 Summarize,
48}
49
50#[derive(
51 Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
52)]
53pub struct ThreadId(Arc<str>);
54
55impl ThreadId {
56 pub fn new() -> Self {
57 Self(Uuid::new_v4().to_string().into())
58 }
59}
60
61impl std::fmt::Display for ThreadId {
62 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63 write!(f, "{}", self.0)
64 }
65}
66
67impl From<&str> for ThreadId {
68 fn from(value: &str) -> Self {
69 Self(value.into())
70 }
71}
72
73#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
74pub struct MessageId(pub(crate) usize);
75
76impl MessageId {
77 fn post_inc(&mut self) -> Self {
78 Self(post_inc(&mut self.0))
79 }
80}
81
82/// A message in a [`Thread`].
83#[derive(Debug, Clone)]
84pub struct Message {
85 pub id: MessageId,
86 pub role: Role,
87 pub segments: Vec<MessageSegment>,
88 pub context: String,
89}
90
91impl Message {
    /// Returns whether the message contains any meaningful text that should be displayed.
    ///
    /// The model sometimes runs tools without producing any text, or produces only the
    /// [`USING_TOOL_MARKER`] placeholder; such messages should not be shown.
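    ///
    /// A minimal sketch of the intended behavior (not a doctest, since [`MessageId`]
    /// can only be constructed inside this crate):
    ///
    /// ```ignore
    /// let mut message = Message {
    ///     id: MessageId(0),
    ///     role: Role::Assistant,
    ///     segments: vec![MessageSegment::Text(String::new())],
    ///     context: String::new(),
    /// };
    /// assert!(!message.should_display_content()); // only an empty segment
    /// message.push_text("Here is the fix.");
    /// assert!(message.should_display_content());
    /// ```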
94 pub fn should_display_content(&self) -> bool {
95 self.segments.iter().all(|segment| segment.should_display())
96 }
97
98 pub fn push_thinking(&mut self, text: &str) {
99 if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
100 segment.push_str(text);
101 } else {
102 self.segments
103 .push(MessageSegment::Thinking(text.to_string()));
104 }
105 }
106
107 pub fn push_text(&mut self, text: &str) {
108 if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
109 segment.push_str(text);
110 } else {
111 self.segments.push(MessageSegment::Text(text.to_string()));
112 }
113 }
114
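    /// Renders the message as plain text: any attached context first, then each
    /// segment in order, with thinking segments wrapped in `<think>` tags.
    ///
    /// A minimal sketch (not a doctest, since [`MessageId`] can only be constructed
    /// inside this crate):
    ///
    /// ```ignore
    /// let mut message = Message {
    ///     id: MessageId(0),
    ///     role: Role::Assistant,
    ///     segments: Vec::new(),
    ///     context: String::new(),
    /// };
    /// message.push_thinking("weighing options");
    /// message.push_text("Use a BTreeMap here.");
    /// assert_eq!(
    ///     message.to_string(),
    ///     "<think>weighing options</think>Use a BTreeMap here."
    /// );
    /// ```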
115 pub fn to_string(&self) -> String {
116 let mut result = String::new();
117
118 if !self.context.is_empty() {
119 result.push_str(&self.context);
120 }
121
122 for segment in &self.segments {
123 match segment {
124 MessageSegment::Text(text) => result.push_str(text),
125 MessageSegment::Thinking(text) => {
126 result.push_str("<think>");
127 result.push_str(text);
128 result.push_str("</think>");
129 }
130 }
131 }
132
133 result
134 }
135}
136
137#[derive(Debug, Clone, PartialEq, Eq)]
138pub enum MessageSegment {
139 Text(String),
140 Thinking(String),
141}
142
143impl MessageSegment {
144 pub fn text_mut(&mut self) -> &mut String {
145 match self {
146 Self::Text(text) => text,
147 Self::Thinking(text) => text,
148 }
149 }
150
151 pub fn should_display(&self) -> bool {
152 // We add USING_TOOL_MARKER when making a request that includes tool uses
153 // without non-whitespace text around them, and this can cause the model
154 // to mimic the pattern, so we consider those segments not displayable.
155 match self {
            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
158 }
159 }
160}
161
162#[derive(Debug, Clone, Serialize, Deserialize)]
163pub struct ProjectSnapshot {
164 pub worktree_snapshots: Vec<WorktreeSnapshot>,
165 pub unsaved_buffer_paths: Vec<String>,
166 pub timestamp: DateTime<Utc>,
167}
168
169#[derive(Debug, Clone, Serialize, Deserialize)]
170pub struct WorktreeSnapshot {
171 pub worktree_path: String,
172 pub git_state: Option<GitState>,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
176pub struct GitState {
177 pub remote_url: Option<String>,
178 pub head_sha: Option<String>,
179 pub current_branch: Option<String>,
180 pub diff: Option<String>,
181}
182
183#[derive(Clone)]
184pub struct ThreadCheckpoint {
185 message_id: MessageId,
186 git_checkpoint: GitStoreCheckpoint,
187}
188
189#[derive(Copy, Clone, Debug, PartialEq, Eq)]
190pub enum ThreadFeedback {
191 Positive,
192 Negative,
193}
194
195pub enum LastRestoreCheckpoint {
196 Pending {
197 message_id: MessageId,
198 },
199 Error {
200 message_id: MessageId,
201 error: String,
202 },
203}
204
205impl LastRestoreCheckpoint {
206 pub fn message_id(&self) -> MessageId {
207 match self {
208 LastRestoreCheckpoint::Pending { message_id } => *message_id,
209 LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
210 }
211 }
212}
213
214#[derive(Clone, Debug, Default, Serialize, Deserialize)]
215pub enum DetailedSummaryState {
216 #[default]
217 NotGenerated,
218 Generating {
219 message_id: MessageId,
220 },
221 Generated {
222 text: SharedString,
223 message_id: MessageId,
224 },
225}
226
227#[derive(Default)]
228pub struct TotalTokenUsage {
229 pub total: usize,
230 pub max: usize,
231}
232
233impl TotalTokenUsage {
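    /// Classifies the current usage relative to the model's context window.
    ///
    /// A sketch of the expected classification, assuming the default 0.8 warning
    /// threshold:
    ///
    /// ```ignore
    /// let usage = TotalTokenUsage { total: 850, max: 1000 };
    /// assert_eq!(usage.ratio(), TokenUsageRatio::Warning);
    /// assert_eq!(usage.add(150).ratio(), TokenUsageRatio::Exceeded);
    /// assert_eq!(
    ///     TotalTokenUsage { total: 100, max: 1000 }.ratio(),
    ///     TokenUsageRatio::Normal
    /// );
    /// ```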
234 pub fn ratio(&self) -> TokenUsageRatio {
235 #[cfg(debug_assertions)]
236 let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
237 .unwrap_or("0.8".to_string())
238 .parse()
239 .unwrap();
240 #[cfg(not(debug_assertions))]
241 let warning_threshold: f32 = 0.8;
242
243 if self.total >= self.max {
244 TokenUsageRatio::Exceeded
245 } else if self.total as f32 / self.max as f32 >= warning_threshold {
246 TokenUsageRatio::Warning
247 } else {
248 TokenUsageRatio::Normal
249 }
250 }
251
252 pub fn add(&self, tokens: usize) -> TotalTokenUsage {
253 TotalTokenUsage {
254 total: self.total + tokens,
255 max: self.max,
256 }
257 }
258}
259
260#[derive(Debug, Default, PartialEq, Eq)]
261pub enum TokenUsageRatio {
262 #[default]
263 Normal,
264 Warning,
265 Exceeded,
266}
267
268/// A thread of conversation with the LLM.
269pub struct Thread {
270 id: ThreadId,
271 updated_at: DateTime<Utc>,
272 summary: Option<SharedString>,
273 pending_summary: Task<Option<()>>,
274 detailed_summary_state: DetailedSummaryState,
275 messages: Vec<Message>,
276 next_message_id: MessageId,
277 context: BTreeMap<ContextId, AssistantContext>,
278 context_by_message: HashMap<MessageId, Vec<ContextId>>,
279 project_context: SharedProjectContext,
280 checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
281 completion_count: usize,
282 pending_completions: Vec<PendingCompletion>,
283 project: Entity<Project>,
284 prompt_builder: Arc<PromptBuilder>,
285 tools: Entity<ToolWorkingSet>,
286 tool_use: ToolUseState,
287 action_log: Entity<ActionLog>,
288 last_restore_checkpoint: Option<LastRestoreCheckpoint>,
289 pending_checkpoint: Option<ThreadCheckpoint>,
290 initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
291 request_token_usage: Vec<TokenUsage>,
292 cumulative_token_usage: TokenUsage,
293 exceeded_window_error: Option<ExceededWindowError>,
294 feedback: Option<ThreadFeedback>,
295 message_feedback: HashMap<MessageId, ThreadFeedback>,
296 last_auto_capture_at: Option<Instant>,
297}
298
299#[derive(Debug, Clone, Serialize, Deserialize)]
300pub struct ExceededWindowError {
    /// The model that was in use when the last message exceeded the context window.
    model_id: LanguageModelId,
    /// The token count, including the last message.
    token_count: usize,
305}
306
307impl Thread {
308 pub fn new(
309 project: Entity<Project>,
310 tools: Entity<ToolWorkingSet>,
311 prompt_builder: Arc<PromptBuilder>,
312 system_prompt: SharedProjectContext,
313 cx: &mut Context<Self>,
314 ) -> Self {
315 Self {
316 id: ThreadId::new(),
317 updated_at: Utc::now(),
318 summary: None,
319 pending_summary: Task::ready(None),
320 detailed_summary_state: DetailedSummaryState::NotGenerated,
321 messages: Vec::new(),
322 next_message_id: MessageId(0),
323 context: BTreeMap::default(),
324 context_by_message: HashMap::default(),
325 project_context: system_prompt,
326 checkpoints_by_message: HashMap::default(),
327 completion_count: 0,
328 pending_completions: Vec::new(),
329 project: project.clone(),
330 prompt_builder,
331 tools: tools.clone(),
332 last_restore_checkpoint: None,
333 pending_checkpoint: None,
334 tool_use: ToolUseState::new(tools.clone()),
335 action_log: cx.new(|_| ActionLog::new(project.clone())),
336 initial_project_snapshot: {
337 let project_snapshot = Self::project_snapshot(project, cx);
338 cx.foreground_executor()
339 .spawn(async move { Some(project_snapshot.await) })
340 .shared()
341 },
342 request_token_usage: Vec::new(),
343 cumulative_token_usage: TokenUsage::default(),
344 exceeded_window_error: None,
345 feedback: None,
346 message_feedback: HashMap::default(),
347 last_auto_capture_at: None,
348 }
349 }
350
351 pub fn deserialize(
352 id: ThreadId,
353 serialized: SerializedThread,
354 project: Entity<Project>,
355 tools: Entity<ToolWorkingSet>,
356 prompt_builder: Arc<PromptBuilder>,
357 project_context: SharedProjectContext,
358 cx: &mut Context<Self>,
359 ) -> Self {
360 let next_message_id = MessageId(
361 serialized
362 .messages
363 .last()
364 .map(|message| message.id.0 + 1)
365 .unwrap_or(0),
366 );
367 let tool_use =
368 ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
369
370 Self {
371 id,
372 updated_at: serialized.updated_at,
373 summary: Some(serialized.summary),
374 pending_summary: Task::ready(None),
375 detailed_summary_state: serialized.detailed_summary_state,
376 messages: serialized
377 .messages
378 .into_iter()
379 .map(|message| Message {
380 id: message.id,
381 role: message.role,
382 segments: message
383 .segments
384 .into_iter()
385 .map(|segment| match segment {
386 SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
387 SerializedMessageSegment::Thinking { text } => {
388 MessageSegment::Thinking(text)
389 }
390 })
391 .collect(),
392 context: message.context,
393 })
394 .collect(),
395 next_message_id,
396 context: BTreeMap::default(),
397 context_by_message: HashMap::default(),
398 project_context,
399 checkpoints_by_message: HashMap::default(),
400 completion_count: 0,
401 pending_completions: Vec::new(),
402 last_restore_checkpoint: None,
403 pending_checkpoint: None,
404 project: project.clone(),
405 prompt_builder,
406 tools,
407 tool_use,
408 action_log: cx.new(|_| ActionLog::new(project)),
409 initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
410 request_token_usage: serialized.request_token_usage,
411 cumulative_token_usage: serialized.cumulative_token_usage,
412 exceeded_window_error: None,
413 feedback: None,
414 message_feedback: HashMap::default(),
415 last_auto_capture_at: None,
416 }
417 }
418
419 pub fn id(&self) -> &ThreadId {
420 &self.id
421 }
422
423 pub fn is_empty(&self) -> bool {
424 self.messages.is_empty()
425 }
426
427 pub fn updated_at(&self) -> DateTime<Utc> {
428 self.updated_at
429 }
430
431 pub fn touch_updated_at(&mut self) {
432 self.updated_at = Utc::now();
433 }
434
435 pub fn summary(&self) -> Option<SharedString> {
436 self.summary.clone()
437 }
438
439 pub fn project_context(&self) -> SharedProjectContext {
440 self.project_context.clone()
441 }
442
443 pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
444
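    /// Falls back to [`Self::DEFAULT_SUMMARY`] ("New Thread") while no summary has
    /// been generated yet.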
445 pub fn summary_or_default(&self) -> SharedString {
446 self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
447 }
448
449 pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
450 let Some(current_summary) = &self.summary else {
451 // Don't allow setting summary until generated
452 return;
453 };
454
455 let mut new_summary = new_summary.into();
456
457 if new_summary.is_empty() {
458 new_summary = Self::DEFAULT_SUMMARY;
459 }
460
461 if current_summary != &new_summary {
462 self.summary = Some(new_summary);
463 cx.emit(ThreadEvent::SummaryChanged);
464 }
465 }
466
467 pub fn latest_detailed_summary_or_text(&self) -> SharedString {
468 self.latest_detailed_summary()
469 .unwrap_or_else(|| self.text().into())
470 }
471
472 fn latest_detailed_summary(&self) -> Option<SharedString> {
473 if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
474 Some(text.clone())
475 } else {
476 None
477 }
478 }
479
480 pub fn message(&self, id: MessageId) -> Option<&Message> {
481 self.messages.iter().find(|message| message.id == id)
482 }
483
484 pub fn messages(&self) -> impl Iterator<Item = &Message> {
485 self.messages.iter()
486 }
487
488 pub fn is_generating(&self) -> bool {
489 !self.pending_completions.is_empty() || !self.all_tools_finished()
490 }
491
492 pub fn tools(&self) -> &Entity<ToolWorkingSet> {
493 &self.tools
494 }
495
496 pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
497 self.tool_use
498 .pending_tool_uses()
499 .into_iter()
500 .find(|tool_use| &tool_use.id == id)
501 }
502
503 pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
504 self.tool_use
505 .pending_tool_uses()
506 .into_iter()
507 .filter(|tool_use| tool_use.status.needs_confirmation())
508 }
509
510 pub fn has_pending_tool_uses(&self) -> bool {
511 !self.tool_use.pending_tool_uses().is_empty()
512 }
513
514 pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
515 self.checkpoints_by_message.get(&id).cloned()
516 }
517
518 pub fn restore_checkpoint(
519 &mut self,
520 checkpoint: ThreadCheckpoint,
521 cx: &mut Context<Self>,
522 ) -> Task<Result<()>> {
523 self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
524 message_id: checkpoint.message_id,
525 });
526 cx.emit(ThreadEvent::CheckpointChanged);
527 cx.notify();
528
529 let git_store = self.project().read(cx).git_store().clone();
530 let restore = git_store.update(cx, |git_store, cx| {
531 git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
532 });
533
534 cx.spawn(async move |this, cx| {
535 let result = restore.await;
536 this.update(cx, |this, cx| {
537 if let Err(err) = result.as_ref() {
538 this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
539 message_id: checkpoint.message_id,
540 error: err.to_string(),
541 });
542 } else {
543 this.truncate(checkpoint.message_id, cx);
544 this.last_restore_checkpoint = None;
545 }
546 this.pending_checkpoint = None;
547 cx.emit(ThreadEvent::CheckpointChanged);
548 cx.notify();
549 })?;
550 result
551 })
552 }
553
554 fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
555 let pending_checkpoint = if self.is_generating() {
556 return;
557 } else if let Some(checkpoint) = self.pending_checkpoint.take() {
558 checkpoint
559 } else {
560 return;
561 };
562
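        // If the repository is unchanged since the pending checkpoint was taken, the
        // checkpoint is deleted; otherwise it is attached to its message below.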
563 let git_store = self.project.read(cx).git_store().clone();
564 let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
565 cx.spawn(async move |this, cx| match final_checkpoint.await {
566 Ok(final_checkpoint) => {
567 let equal = git_store
568 .update(cx, |store, cx| {
569 store.compare_checkpoints(
570 pending_checkpoint.git_checkpoint.clone(),
571 final_checkpoint.clone(),
572 cx,
573 )
574 })?
575 .await
576 .unwrap_or(false);
577
578 if equal {
579 git_store
580 .update(cx, |store, cx| {
581 store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
582 })?
583 .detach();
584 } else {
585 this.update(cx, |this, cx| {
586 this.insert_checkpoint(pending_checkpoint, cx)
587 })?;
588 }
589
590 git_store
591 .update(cx, |store, cx| {
592 store.delete_checkpoint(final_checkpoint, cx)
593 })?
594 .detach();
595
596 Ok(())
597 }
598 Err(_) => this.update(cx, |this, cx| {
599 this.insert_checkpoint(pending_checkpoint, cx)
600 }),
601 })
602 .detach();
603 }
604
605 fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
606 self.checkpoints_by_message
607 .insert(checkpoint.message_id, checkpoint);
608 cx.emit(ThreadEvent::CheckpointChanged);
609 cx.notify();
610 }
611
612 pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
613 self.last_restore_checkpoint.as_ref()
614 }
615
616 pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
617 let Some(message_ix) = self
618 .messages
619 .iter()
620 .rposition(|message| message.id == message_id)
621 else {
622 return;
623 };
624 for deleted_message in self.messages.drain(message_ix..) {
625 self.context_by_message.remove(&deleted_message.id);
626 self.checkpoints_by_message.remove(&deleted_message.id);
627 }
628 cx.notify();
629 }
630
631 pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
632 self.context_by_message
633 .get(&id)
634 .into_iter()
635 .flat_map(|context| {
636 context
637 .iter()
638 .filter_map(|context_id| self.context.get(&context_id))
639 })
640 }
641
642 /// Returns whether all of the tool uses have finished running.
643 pub fn all_tools_finished(&self) -> bool {
        // If the only pending tool uses left are the ones with errors, then
        // we've finished running all of the pending tools.
646 self.tool_use
647 .pending_tool_uses()
648 .iter()
649 .all(|tool_use| tool_use.status.is_error())
650 }
651
652 pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
653 self.tool_use.tool_uses_for_message(id, cx)
654 }
655
656 pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
657 self.tool_use.tool_results_for_message(id)
658 }
659
660 pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
661 self.tool_use.tool_result(id)
662 }
663
664 pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
665 Some(&self.tool_use.tool_result(id)?.content)
666 }
667
668 pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
669 self.tool_use.tool_result_card(id).cloned()
670 }
671
672 pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
673 self.tool_use.message_has_tool_results(message_id)
674 }
675
676 /// Filter out contexts that have already been included in previous messages
677 pub fn filter_new_context<'a>(
678 &self,
679 context: impl Iterator<Item = &'a AssistantContext>,
680 ) -> impl Iterator<Item = &'a AssistantContext> {
681 context.filter(|ctx| self.is_context_new(ctx))
682 }
683
684 fn is_context_new(&self, context: &AssistantContext) -> bool {
685 !self.context.contains_key(&context.id())
686 }
687
688 pub fn insert_user_message(
689 &mut self,
690 text: impl Into<String>,
691 context: Vec<AssistantContext>,
692 git_checkpoint: Option<GitStoreCheckpoint>,
693 cx: &mut Context<Self>,
694 ) -> MessageId {
695 let text = text.into();
696
697 let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
698
699 let new_context: Vec<_> = context
700 .into_iter()
701 .filter(|ctx| self.is_context_new(ctx))
702 .collect();
703
704 if !new_context.is_empty() {
705 if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
706 if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
707 message.context = context_string;
708 }
709 }
710
711 self.action_log.update(cx, |log, cx| {
712 // Track all buffers added as context
713 for ctx in &new_context {
714 match ctx {
715 AssistantContext::File(file_ctx) => {
716 log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
717 }
718 AssistantContext::Directory(dir_ctx) => {
719 for context_buffer in &dir_ctx.context_buffers {
720 log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
721 }
722 }
723 AssistantContext::Symbol(symbol_ctx) => {
724 log.buffer_added_as_context(
725 symbol_ctx.context_symbol.buffer.clone(),
726 cx,
727 );
728 }
729 AssistantContext::Excerpt(excerpt_context) => {
730 log.buffer_added_as_context(
731 excerpt_context.context_buffer.buffer.clone(),
732 cx,
733 );
734 }
735 AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
736 }
737 }
738 });
739 }
740
741 let context_ids = new_context
742 .iter()
743 .map(|context| context.id())
744 .collect::<Vec<_>>();
745 self.context.extend(
746 new_context
747 .into_iter()
748 .map(|context| (context.id(), context)),
749 );
750 self.context_by_message.insert(message_id, context_ids);
751
752 if let Some(git_checkpoint) = git_checkpoint {
753 self.pending_checkpoint = Some(ThreadCheckpoint {
754 message_id,
755 git_checkpoint,
756 });
757 }
758
759 self.auto_capture_telemetry(cx);
760
761 message_id
762 }
763
764 pub fn insert_message(
765 &mut self,
766 role: Role,
767 segments: Vec<MessageSegment>,
768 cx: &mut Context<Self>,
769 ) -> MessageId {
770 let id = self.next_message_id.post_inc();
771 self.messages.push(Message {
772 id,
773 role,
774 segments,
775 context: String::new(),
776 });
777 self.touch_updated_at();
778 cx.emit(ThreadEvent::MessageAdded(id));
779 id
780 }
781
782 pub fn edit_message(
783 &mut self,
784 id: MessageId,
785 new_role: Role,
786 new_segments: Vec<MessageSegment>,
787 cx: &mut Context<Self>,
788 ) -> bool {
789 let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
790 return false;
791 };
792 message.role = new_role;
793 message.segments = new_segments;
794 self.touch_updated_at();
795 cx.emit(ThreadEvent::MessageEdited(id));
796 true
797 }
798
799 pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
800 let Some(index) = self.messages.iter().position(|message| message.id == id) else {
801 return false;
802 };
803 self.messages.remove(index);
804 self.context_by_message.remove(&id);
805 self.touch_updated_at();
806 cx.emit(ThreadEvent::MessageDeleted(id));
807 true
808 }
809
810 /// Returns the representation of this [`Thread`] in a textual form.
811 ///
812 /// This is the representation we use when attaching a thread as context to another thread.
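    ///
    /// The output is shaped roughly like this (illustrative, not an exact transcript):
    ///
    /// ```text
    /// User:
    /// How should we cache these requests?
    /// Assistant:
    /// <think>considering trade-offs</think>Use the existing checkpoint store.
    /// ```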
813 pub fn text(&self) -> String {
814 let mut text = String::new();
815
816 for message in &self.messages {
817 text.push_str(match message.role {
818 language_model::Role::User => "User:",
819 language_model::Role::Assistant => "Assistant:",
820 language_model::Role::System => "System:",
821 });
822 text.push('\n');
823
824 for segment in &message.segments {
825 match segment {
826 MessageSegment::Text(content) => text.push_str(content),
827 MessageSegment::Thinking(content) => {
828 text.push_str(&format!("<think>{}</think>", content))
829 }
830 }
831 }
832 text.push('\n');
833 }
834
835 text
836 }
837
838 /// Serializes this thread into a format for storage or telemetry.
839 pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
840 let initial_project_snapshot = self.initial_project_snapshot.clone();
841 cx.spawn(async move |this, cx| {
842 let initial_project_snapshot = initial_project_snapshot.await;
843 this.read_with(cx, |this, cx| SerializedThread {
844 version: SerializedThread::VERSION.to_string(),
845 summary: this.summary_or_default(),
846 updated_at: this.updated_at(),
847 messages: this
848 .messages()
849 .map(|message| SerializedMessage {
850 id: message.id,
851 role: message.role,
852 segments: message
853 .segments
854 .iter()
855 .map(|segment| match segment {
856 MessageSegment::Text(text) => {
857 SerializedMessageSegment::Text { text: text.clone() }
858 }
859 MessageSegment::Thinking(text) => {
860 SerializedMessageSegment::Thinking { text: text.clone() }
861 }
862 })
863 .collect(),
864 tool_uses: this
865 .tool_uses_for_message(message.id, cx)
866 .into_iter()
867 .map(|tool_use| SerializedToolUse {
868 id: tool_use.id,
869 name: tool_use.name,
870 input: tool_use.input,
871 })
872 .collect(),
873 tool_results: this
874 .tool_results_for_message(message.id)
875 .into_iter()
876 .map(|tool_result| SerializedToolResult {
877 tool_use_id: tool_result.tool_use_id.clone(),
878 is_error: tool_result.is_error,
879 content: tool_result.content.clone(),
880 })
881 .collect(),
882 context: message.context.clone(),
883 })
884 .collect(),
885 initial_project_snapshot,
886 cumulative_token_usage: this.cumulative_token_usage,
887 request_token_usage: this.request_token_usage.clone(),
888 detailed_summary_state: this.detailed_summary_state.clone(),
889 exceeded_window_error: this.exceeded_window_error.clone(),
890 })
891 })
892 }
893
894 pub fn send_to_model(
895 &mut self,
896 model: Arc<dyn LanguageModel>,
897 request_kind: RequestKind,
898 cx: &mut Context<Self>,
899 ) {
900 let mut request = self.to_completion_request(request_kind, cx);
901 if model.supports_tools() {
902 request.tools = {
903 let mut tools = Vec::new();
904 tools.extend(
905 self.tools()
906 .read(cx)
907 .enabled_tools(cx)
908 .into_iter()
909 .filter_map(|tool| {
910 // Skip tools that cannot be supported
911 let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
912 Some(LanguageModelRequestTool {
913 name: tool.name(),
914 description: tool.description(),
915 input_schema,
916 })
917 }),
918 );
919
920 tools
921 };
922 }
923
924 self.stream_completion(request, model, cx);
925 }
926
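    /// Reports whether any tool results appear after the most recent user message,
    /// walking the messages from newest to oldest.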
927 pub fn used_tools_since_last_user_message(&self) -> bool {
928 for message in self.messages.iter().rev() {
929 if self.tool_use.message_has_tool_results(message.id) {
930 return true;
931 } else if message.role == Role::User {
932 return false;
933 }
934 }
935
936 false
937 }
938
939 pub fn to_completion_request(
940 &self,
941 request_kind: RequestKind,
942 cx: &App,
943 ) -> LanguageModelRequest {
944 let mut request = LanguageModelRequest {
945 messages: vec![],
946 tools: Vec::new(),
947 stop: Vec::new(),
948 temperature: None,
949 };
950
951 if let Some(project_context) = self.project_context.borrow().as_ref() {
952 if let Some(system_prompt) = self
953 .prompt_builder
954 .generate_assistant_system_prompt(project_context)
955 .context("failed to generate assistant system prompt")
956 .log_err()
957 {
958 request.messages.push(LanguageModelRequestMessage {
959 role: Role::System,
960 content: vec![MessageContent::Text(system_prompt)],
961 cache: true,
962 });
963 }
964 } else {
965 log::error!("project_context not set.")
966 }
967
968 for message in &self.messages {
969 let mut request_message = LanguageModelRequestMessage {
970 role: message.role,
971 content: Vec::new(),
972 cache: false,
973 };
974
975 match request_kind {
976 RequestKind::Chat => {
977 self.tool_use
978 .attach_tool_results(message.id, &mut request_message);
979 }
980 RequestKind::Summarize => {
981 // We don't care about tool use during summarization.
982 if self.tool_use.message_has_tool_results(message.id) {
983 continue;
984 }
985 }
986 }
987
988 if !message.segments.is_empty() {
989 request_message
990 .content
991 .push(MessageContent::Text(message.to_string()));
992 }
993
994 match request_kind {
995 RequestKind::Chat => {
996 self.tool_use
997 .attach_tool_uses(message.id, &mut request_message);
998 }
999 RequestKind::Summarize => {
1000 // We don't care about tool use during summarization.
1001 }
1002 };
1003
1004 request.messages.push(request_message);
1005 }
1006
1007 // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1008 if let Some(last) = request.messages.last_mut() {
1009 last.cache = true;
1010 }
1011
1012 self.attached_tracked_files_state(&mut request.messages, cx);
1013
1014 request
1015 }
1016
1017 fn attached_tracked_files_state(
1018 &self,
1019 messages: &mut Vec<LanguageModelRequestMessage>,
1020 cx: &App,
1021 ) {
1022 const STALE_FILES_HEADER: &str = "These files changed since last read:";
1023
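        // The injected message body ends up looking roughly like this
        // (paths are illustrative):
        //
        //   These files changed since last read:
        //   - src/thread.rs
        //   - src/tool_use.rs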
1024 let mut stale_message = String::new();
1025
1026 let action_log = self.action_log.read(cx);
1027
1028 for stale_file in action_log.stale_buffers(cx) {
1029 let Some(file) = stale_file.read(cx).file() else {
1030 continue;
1031 };
1032
1033 if stale_message.is_empty() {
            writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
1035 }
1036
1037 writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1038 }
1039
1040 let mut content = Vec::with_capacity(2);
1041
1042 if !stale_message.is_empty() {
1043 content.push(stale_message.into());
1044 }
1045
1046 if action_log.has_edited_files_since_project_diagnostics_check() {
1047 content.push(
1048 "\n\nWhen you're done making changes, make sure to check project diagnostics \
1049 and fix all errors AND warnings you introduced! \
1050 DO NOT mention you're going to do this until you're done."
1051 .into(),
1052 );
1053 }
1054
1055 if !content.is_empty() {
1056 let context_message = LanguageModelRequestMessage {
1057 role: Role::User,
1058 content,
1059 cache: false,
1060 };
1061
1062 messages.push(context_message);
1063 }
1064 }
1065
1066 pub fn stream_completion(
1067 &mut self,
1068 request: LanguageModelRequest,
1069 model: Arc<dyn LanguageModel>,
1070 cx: &mut Context<Self>,
1071 ) {
1072 let pending_completion_id = post_inc(&mut self.completion_count);
1073 let task = cx.spawn(async move |thread, cx| {
1074 let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1075 let initial_token_usage =
1076 thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1077 let stream_completion = async {
1078 let (mut events, usage) = stream_completion_future.await?;
1079 let mut stop_reason = StopReason::EndTurn;
1080 let mut current_token_usage = TokenUsage::default();
1081
1082 if let Some(usage) = usage {
1083 thread
1084 .update(cx, |_thread, cx| {
1085 cx.emit(ThreadEvent::UsageUpdated(usage));
1086 })
1087 .ok();
1088 }
1089
1090 while let Some(event) = events.next().await {
1091 let event = event?;
1092
1093 thread.update(cx, |thread, cx| {
1094 match event {
1095 LanguageModelCompletionEvent::StartMessage { .. } => {
1096 thread.insert_message(
1097 Role::Assistant,
1098 vec![MessageSegment::Text(String::new())],
1099 cx,
1100 );
1101 }
1102 LanguageModelCompletionEvent::Stop(reason) => {
1103 stop_reason = reason;
1104 }
1105 LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1106 thread.update_token_usage_at_last_message(token_usage);
1107 thread.cumulative_token_usage = thread.cumulative_token_usage
1108 + token_usage
1109 - current_token_usage;
1110 current_token_usage = token_usage;
1111 }
1112 LanguageModelCompletionEvent::Text(chunk) => {
1113 if let Some(last_message) = thread.messages.last_mut() {
1114 if last_message.role == Role::Assistant {
1115 last_message.push_text(&chunk);
1116 cx.emit(ThreadEvent::StreamedAssistantText(
1117 last_message.id,
1118 chunk,
1119 ));
1120 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1122 // of a new Assistant response.
1123 //
1124 // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1125 // will result in duplicating the text of the chunk in the rendered Markdown.
1126 thread.insert_message(
1127 Role::Assistant,
1128 vec![MessageSegment::Text(chunk.to_string())],
1129 cx,
1130 );
1131 };
1132 }
1133 }
1134 LanguageModelCompletionEvent::Thinking(chunk) => {
1135 if let Some(last_message) = thread.messages.last_mut() {
1136 if last_message.role == Role::Assistant {
1137 last_message.push_thinking(&chunk);
1138 cx.emit(ThreadEvent::StreamedAssistantThinking(
1139 last_message.id,
1140 chunk,
1141 ));
1142 } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1148 thread.insert_message(
1149 Role::Assistant,
1150 vec![MessageSegment::Thinking(chunk.to_string())],
1151 cx,
1152 );
1153 };
1154 }
1155 }
1156 LanguageModelCompletionEvent::ToolUse(tool_use) => {
1157 let last_assistant_message_id = thread
1158 .messages
1159 .iter_mut()
1160 .rfind(|message| message.role == Role::Assistant)
1161 .map(|message| message.id)
1162 .unwrap_or_else(|| {
1163 thread.insert_message(Role::Assistant, vec![], cx)
1164 });
1165
1166 thread.tool_use.request_tool_use(
1167 last_assistant_message_id,
1168 tool_use,
1169 cx,
1170 );
1171 }
1172 }
1173
1174 thread.touch_updated_at();
1175 cx.emit(ThreadEvent::StreamedCompletion);
1176 cx.notify();
1177
1178 thread.auto_capture_telemetry(cx);
1179 })?;
1180
1181 smol::future::yield_now().await;
1182 }
1183
1184 thread.update(cx, |thread, cx| {
1185 thread
1186 .pending_completions
1187 .retain(|completion| completion.id != pending_completion_id);
1188
1189 if thread.summary.is_none() && thread.messages.len() >= 2 {
1190 thread.summarize(cx);
1191 }
1192 })?;
1193
1194 anyhow::Ok(stop_reason)
1195 };
1196
1197 let result = stream_completion.await;
1198
1199 thread
1200 .update(cx, |thread, cx| {
1201 thread.finalize_pending_checkpoint(cx);
1202 match result.as_ref() {
1203 Ok(stop_reason) => match stop_reason {
1204 StopReason::ToolUse => {
1205 let tool_uses = thread.use_pending_tools(cx);
1206 cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1207 }
1208 StopReason::EndTurn => {}
1209 StopReason::MaxTokens => {}
1210 },
1211 Err(error) => {
1212 if error.is::<PaymentRequiredError>() {
1213 cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1214 } else if error.is::<MaxMonthlySpendReachedError>() {
1215 cx.emit(ThreadEvent::ShowError(
1216 ThreadError::MaxMonthlySpendReached,
1217 ));
1218 } else if let Some(error) =
1219 error.downcast_ref::<ModelRequestLimitReachedError>()
1220 {
1221 cx.emit(ThreadEvent::ShowError(
1222 ThreadError::ModelRequestLimitReached { plan: error.plan },
1223 ));
1224 } else if let Some(known_error) =
1225 error.downcast_ref::<LanguageModelKnownError>()
1226 {
1227 match known_error {
1228 LanguageModelKnownError::ContextWindowLimitExceeded {
1229 tokens,
1230 } => {
1231 thread.exceeded_window_error = Some(ExceededWindowError {
1232 model_id: model.id(),
1233 token_count: *tokens,
1234 });
1235 cx.notify();
1236 }
1237 }
1238 } else {
1239 let error_message = error
1240 .chain()
1241 .map(|err| err.to_string())
1242 .collect::<Vec<_>>()
1243 .join("\n");
1244 cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1245 header: "Error interacting with language model".into(),
1246 message: SharedString::from(error_message.clone()),
1247 }));
1248 }
1249
1250 thread.cancel_last_completion(cx);
1251 }
1252 }
1253 cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1254
1255 thread.auto_capture_telemetry(cx);
1256
1257 if let Ok(initial_usage) = initial_token_usage {
1258 let usage = thread.cumulative_token_usage - initial_usage;
1259
1260 telemetry::event!(
1261 "Assistant Thread Completion",
1262 thread_id = thread.id().to_string(),
1263 model = model.telemetry_id(),
1264 model_provider = model.provider_id().to_string(),
1265 input_tokens = usage.input_tokens,
1266 output_tokens = usage.output_tokens,
1267 cache_creation_input_tokens = usage.cache_creation_input_tokens,
1268 cache_read_input_tokens = usage.cache_read_input_tokens,
1269 );
1270 }
1271 })
1272 .ok();
1273 });
1274
1275 self.pending_completions.push(PendingCompletion {
1276 id: pending_completion_id,
1277 _task: task,
1278 });
1279 }
1280
1281 pub fn summarize(&mut self, cx: &mut Context<Self>) {
1282 let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1283 return;
1284 };
1285
1286 if !model.provider.is_authenticated(cx) {
1287 return;
1288 }
1289
1290 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1291 request.messages.push(LanguageModelRequestMessage {
1292 role: Role::User,
1293 content: vec![
1294 "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1295 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1296 If the conversation is about a specific subject, include it in the title. \
1297 Be descriptive. DO NOT speak in the first person."
1298 .into(),
1299 ],
1300 cache: false,
1301 });
1302
1303 self.pending_summary = cx.spawn(async move |this, cx| {
1304 async move {
1305 let stream = model.model.stream_completion_text_with_usage(request, &cx);
1306 let (mut messages, usage) = stream.await?;
1307
1308 if let Some(usage) = usage {
1309 this.update(cx, |_thread, cx| {
1310 cx.emit(ThreadEvent::UsageUpdated(usage));
1311 })
1312 .ok();
1313 }
1314
1315 let mut new_summary = String::new();
1316 while let Some(message) = messages.stream.next().await {
1317 let text = message?;
1318 let mut lines = text.lines();
1319 new_summary.extend(lines.next());
1320
1321 // Stop if the LLM generated multiple lines.
1322 if lines.next().is_some() {
1323 break;
1324 }
1325 }
1326
1327 this.update(cx, |this, cx| {
1328 if !new_summary.is_empty() {
1329 this.summary = Some(new_summary.into());
1330 }
1331
1332 cx.emit(ThreadEvent::SummaryGenerated);
1333 })?;
1334
1335 anyhow::Ok(())
1336 }
1337 .log_err()
1338 .await
1339 });
1340 }
1341
1342 pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1343 let last_message_id = self.messages.last().map(|message| message.id)?;
1344
1345 match &self.detailed_summary_state {
1346 DetailedSummaryState::Generating { message_id, .. }
1347 | DetailedSummaryState::Generated { message_id, .. }
1348 if *message_id == last_message_id =>
1349 {
1350 // Already up-to-date
1351 return None;
1352 }
1353 _ => {}
1354 }
1355
1356 let ConfiguredModel { model, provider } =
1357 LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1358
1359 if !provider.is_authenticated(cx) {
1360 return None;
1361 }
1362
1363 let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1364
1365 request.messages.push(LanguageModelRequestMessage {
1366 role: Role::User,
1367 content: vec![
1368 "Generate a detailed summary of this conversation. Include:\n\
1369 1. A brief overview of what was discussed\n\
1370 2. Key facts or information discovered\n\
1371 3. Outcomes or conclusions reached\n\
1372 4. Any action items or next steps if any\n\
1373 Format it in Markdown with headings and bullet points."
1374 .into(),
1375 ],
1376 cache: false,
1377 });
1378
1379 let task = cx.spawn(async move |thread, cx| {
1380 let stream = model.stream_completion_text(request, &cx);
1381 let Some(mut messages) = stream.await.log_err() else {
1382 thread
1383 .update(cx, |this, _cx| {
1384 this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1385 })
1386 .log_err();
1387
1388 return;
1389 };
1390
1391 let mut new_detailed_summary = String::new();
1392
1393 while let Some(chunk) = messages.stream.next().await {
1394 if let Some(chunk) = chunk.log_err() {
1395 new_detailed_summary.push_str(&chunk);
1396 }
1397 }
1398
1399 thread
1400 .update(cx, |this, _cx| {
1401 this.detailed_summary_state = DetailedSummaryState::Generated {
1402 text: new_detailed_summary.into(),
1403 message_id: last_message_id,
1404 };
1405 })
1406 .log_err();
1407 });
1408
1409 self.detailed_summary_state = DetailedSummaryState::Generating {
1410 message_id: last_message_id,
1411 };
1412
1413 Some(task)
1414 }
1415
1416 pub fn is_generating_detailed_summary(&self) -> bool {
1417 matches!(
1418 self.detailed_summary_state,
1419 DetailedSummaryState::Generating { .. }
1420 )
1421 }
1422
1423 pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1424 self.auto_capture_telemetry(cx);
1425 let request = self.to_completion_request(RequestKind::Chat, cx);
1426 let messages = Arc::new(request.messages);
1427 let pending_tool_uses = self
1428 .tool_use
1429 .pending_tool_uses()
1430 .into_iter()
1431 .filter(|tool_use| tool_use.status.is_idle())
1432 .cloned()
1433 .collect::<Vec<_>>();
1434
1435 for tool_use in pending_tool_uses.iter() {
1436 if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1437 if tool.needs_confirmation(&tool_use.input, cx)
1438 && !AssistantSettings::get_global(cx).always_allow_tool_actions
1439 {
1440 self.tool_use.confirm_tool_use(
1441 tool_use.id.clone(),
1442 tool_use.ui_text.clone(),
1443 tool_use.input.clone(),
1444 messages.clone(),
1445 tool,
1446 );
1447 cx.emit(ThreadEvent::ToolConfirmationNeeded);
1448 } else {
1449 self.run_tool(
1450 tool_use.id.clone(),
1451 tool_use.ui_text.clone(),
1452 tool_use.input.clone(),
1453 &messages,
1454 tool,
1455 cx,
1456 );
1457 }
1458 }
1459 }
1460
1461 pending_tool_uses
1462 }
1463
1464 pub fn run_tool(
1465 &mut self,
1466 tool_use_id: LanguageModelToolUseId,
1467 ui_text: impl Into<SharedString>,
1468 input: serde_json::Value,
1469 messages: &[LanguageModelRequestMessage],
1470 tool: Arc<dyn Tool>,
1471 cx: &mut Context<Thread>,
1472 ) {
1473 let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1474 self.tool_use
1475 .run_pending_tool(tool_use_id, ui_text.into(), task);
1476 }
1477
1478 fn spawn_tool_use(
1479 &mut self,
1480 tool_use_id: LanguageModelToolUseId,
1481 messages: &[LanguageModelRequestMessage],
1482 input: serde_json::Value,
1483 tool: Arc<dyn Tool>,
1484 cx: &mut Context<Thread>,
1485 ) -> Task<()> {
1486 let tool_name: Arc<str> = tool.name().into();
1487
1488 let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1489 Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1490 } else {
1491 tool.run(
1492 input,
1493 messages,
1494 self.project.clone(),
1495 self.action_log.clone(),
1496 cx,
1497 )
1498 };
1499
1500 // Store the card separately if it exists
1501 if let Some(card) = tool_result.card.clone() {
1502 self.tool_use
1503 .insert_tool_result_card(tool_use_id.clone(), card);
1504 }
1505
1506 cx.spawn({
1507 async move |thread: WeakEntity<Thread>, cx| {
1508 let output = tool_result.output.await;
1509
1510 thread
1511 .update(cx, |thread, cx| {
1512 let pending_tool_use = thread.tool_use.insert_tool_output(
1513 tool_use_id.clone(),
1514 tool_name,
1515 output,
1516 cx,
1517 );
1518 thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1519 })
1520 .ok();
1521 }
1522 })
1523 }
1524
1525 fn tool_finished(
1526 &mut self,
1527 tool_use_id: LanguageModelToolUseId,
1528 pending_tool_use: Option<PendingToolUse>,
1529 canceled: bool,
1530 cx: &mut Context<Self>,
1531 ) {
1532 if self.all_tools_finished() {
1533 let model_registry = LanguageModelRegistry::read_global(cx);
1534 if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1535 self.attach_tool_results(cx);
1536 if !canceled {
1537 self.send_to_model(model, RequestKind::Chat, cx);
1538 }
1539 }
1540 }
1541
1542 cx.emit(ThreadEvent::ToolFinished {
1543 tool_use_id,
1544 pending_tool_use,
1545 });
1546 }
1547
1548 pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1549 // Insert a user message to contain the tool results.
1550 self.insert_user_message(
1551 // TODO: Sending up a user message without any content results in the model sending back
1552 // responses that also don't have any content. We currently don't handle this case well,
1553 // so for now we provide some text to keep the model on track.
1554 "Here are the tool results.",
1555 Vec::new(),
1556 None,
1557 cx,
1558 );
1559 }
1560
1561 /// Cancels the last pending completion, if there are any pending.
1562 ///
1563 /// Returns whether a completion was canceled.
1564 pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1565 let canceled = if self.pending_completions.pop().is_some() {
1566 true
1567 } else {
1568 let mut canceled = false;
1569 for pending_tool_use in self.tool_use.cancel_pending() {
1570 canceled = true;
1571 self.tool_finished(
1572 pending_tool_use.id.clone(),
1573 Some(pending_tool_use),
1574 true,
1575 cx,
1576 );
1577 }
1578 canceled
1579 };
1580 self.finalize_pending_checkpoint(cx);
1581 canceled
1582 }
1583
1584 pub fn feedback(&self) -> Option<ThreadFeedback> {
1585 self.feedback
1586 }
1587
1588 pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1589 self.message_feedback.get(&message_id).copied()
1590 }
1591
1592 pub fn report_message_feedback(
1593 &mut self,
1594 message_id: MessageId,
1595 feedback: ThreadFeedback,
1596 cx: &mut Context<Self>,
1597 ) -> Task<Result<()>> {
1598 if self.message_feedback.get(&message_id) == Some(&feedback) {
1599 return Task::ready(Ok(()));
1600 }
1601
1602 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1603 let serialized_thread = self.serialize(cx);
1604 let thread_id = self.id().clone();
1605 let client = self.project.read(cx).client();
1606
1607 let enabled_tool_names: Vec<String> = self
1608 .tools()
1609 .read(cx)
1610 .enabled_tools(cx)
1611 .iter()
1612 .map(|tool| tool.name().to_string())
1613 .collect();
1614
1615 self.message_feedback.insert(message_id, feedback);
1616
1617 cx.notify();
1618
1619 let message_content = self
1620 .message(message_id)
1621 .map(|msg| msg.to_string())
1622 .unwrap_or_default();
1623
1624 cx.background_spawn(async move {
1625 let final_project_snapshot = final_project_snapshot.await;
1626 let serialized_thread = serialized_thread.await?;
1627 let thread_data =
1628 serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1629
1630 let rating = match feedback {
1631 ThreadFeedback::Positive => "positive",
1632 ThreadFeedback::Negative => "negative",
1633 };
1634 telemetry::event!(
1635 "Assistant Thread Rated",
1636 rating,
1637 thread_id,
1638 enabled_tool_names,
1639 message_id = message_id.0,
1640 message_content,
1641 thread_data,
1642 final_project_snapshot
1643 );
1644 client.telemetry().flush_events();
1645
1646 Ok(())
1647 })
1648 }
1649
1650 pub fn report_feedback(
1651 &mut self,
1652 feedback: ThreadFeedback,
1653 cx: &mut Context<Self>,
1654 ) -> Task<Result<()>> {
1655 let last_assistant_message_id = self
1656 .messages
1657 .iter()
1658 .rev()
1659 .find(|msg| msg.role == Role::Assistant)
1660 .map(|msg| msg.id);
1661
1662 if let Some(message_id) = last_assistant_message_id {
1663 self.report_message_feedback(message_id, feedback, cx)
1664 } else {
1665 let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1666 let serialized_thread = self.serialize(cx);
1667 let thread_id = self.id().clone();
1668 let client = self.project.read(cx).client();
1669 self.feedback = Some(feedback);
1670 cx.notify();
1671
1672 cx.background_spawn(async move {
1673 let final_project_snapshot = final_project_snapshot.await;
1674 let serialized_thread = serialized_thread.await?;
1675 let thread_data = serde_json::to_value(serialized_thread)
1676 .unwrap_or_else(|_| serde_json::Value::Null);
1677
1678 let rating = match feedback {
1679 ThreadFeedback::Positive => "positive",
1680 ThreadFeedback::Negative => "negative",
1681 };
1682 telemetry::event!(
1683 "Assistant Thread Rated",
1684 rating,
1685 thread_id,
1686 thread_data,
1687 final_project_snapshot
1688 );
1689 client.telemetry().flush_events();
1690
1691 Ok(())
1692 })
1693 }
1694 }
1695
1696 /// Create a snapshot of the current project state including git information and unsaved buffers.
1697 fn project_snapshot(
1698 project: Entity<Project>,
1699 cx: &mut Context<Self>,
1700 ) -> Task<Arc<ProjectSnapshot>> {
1701 let git_store = project.read(cx).git_store().clone();
1702 let worktree_snapshots: Vec<_> = project
1703 .read(cx)
1704 .visible_worktrees(cx)
1705 .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1706 .collect();
1707
1708 cx.spawn(async move |_, cx| {
1709 let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1710
1711 let mut unsaved_buffers = Vec::new();
1712 cx.update(|app_cx| {
1713 let buffer_store = project.read(app_cx).buffer_store();
1714 for buffer_handle in buffer_store.read(app_cx).buffers() {
1715 let buffer = buffer_handle.read(app_cx);
1716 if buffer.is_dirty() {
1717 if let Some(file) = buffer.file() {
1718 let path = file.path().to_string_lossy().to_string();
1719 unsaved_buffers.push(path);
1720 }
1721 }
1722 }
1723 })
1724 .ok();
1725
1726 Arc::new(ProjectSnapshot {
1727 worktree_snapshots,
1728 unsaved_buffer_paths: unsaved_buffers,
1729 timestamp: Utc::now(),
1730 })
1731 })
1732 }
1733
1734 fn worktree_snapshot(
1735 worktree: Entity<project::Worktree>,
1736 git_store: Entity<GitStore>,
1737 cx: &App,
1738 ) -> Task<WorktreeSnapshot> {
1739 cx.spawn(async move |cx| {
1740 // Get worktree path and snapshot
1741 let worktree_info = cx.update(|app_cx| {
1742 let worktree = worktree.read(app_cx);
1743 let path = worktree.abs_path().to_string_lossy().to_string();
1744 let snapshot = worktree.snapshot();
1745 (path, snapshot)
1746 });
1747
1748 let Ok((worktree_path, _snapshot)) = worktree_info else {
1749 return WorktreeSnapshot {
1750 worktree_path: String::new(),
1751 git_state: None,
1752 };
1753 };
1754
1755 let git_state = git_store
1756 .update(cx, |git_store, cx| {
1757 git_store
1758 .repositories()
1759 .values()
1760 .find(|repo| {
1761 repo.read(cx)
1762 .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1763 .is_some()
1764 })
1765 .cloned()
1766 })
1767 .ok()
1768 .flatten()
1769 .map(|repo| {
1770 repo.update(cx, |repo, _| {
1771 let current_branch =
1772 repo.branch.as_ref().map(|branch| branch.name.to_string());
1773 repo.send_job(None, |state, _| async move {
1774 let RepositoryState::Local { backend, .. } = state else {
1775 return GitState {
1776 remote_url: None,
1777 head_sha: None,
1778 current_branch,
1779 diff: None,
1780 };
1781 };
1782
1783 let remote_url = backend.remote_url("origin");
1784 let head_sha = backend.head_sha();
1785 let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1786
1787 GitState {
1788 remote_url,
1789 head_sha,
1790 current_branch,
1791 diff,
1792 }
1793 })
1794 })
1795 });
1796
1797 let git_state = match git_state {
1798 Some(git_state) => match git_state.ok() {
1799 Some(git_state) => git_state.await.ok(),
1800 None => None,
1801 },
1802 None => None,
1803 };
1804
1805 WorktreeSnapshot {
1806 worktree_path,
1807 git_state,
1808 }
1809 })
1810 }
1811
1812 pub fn to_markdown(&self, cx: &App) -> Result<String> {
1813 let mut markdown = Vec::new();
1814
1815 if let Some(summary) = self.summary() {
1816 writeln!(markdown, "# {summary}\n")?;
1817 };
1818
1819 for message in self.messages() {
1820 writeln!(
1821 markdown,
1822 "## {role}\n",
1823 role = match message.role {
1824 Role::User => "User",
1825 Role::Assistant => "Assistant",
1826 Role::System => "System",
1827 }
1828 )?;
1829
1830 if !message.context.is_empty() {
1831 writeln!(markdown, "{}", message.context)?;
1832 }
1833
1834 for segment in &message.segments {
1835 match segment {
1836 MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1837 MessageSegment::Thinking(text) => {
1838 writeln!(markdown, "<think>{}</think>\n", text)?
1839 }
1840 }
1841 }
1842
1843 for tool_use in self.tool_uses_for_message(message.id, cx) {
1844 writeln!(
1845 markdown,
1846 "**Use Tool: {} ({})**",
1847 tool_use.name, tool_use.id
1848 )?;
1849 writeln!(markdown, "```json")?;
1850 writeln!(
1851 markdown,
1852 "{}",
1853 serde_json::to_string_pretty(&tool_use.input)?
1854 )?;
1855 writeln!(markdown, "```")?;
1856 }
1857
1858 for tool_result in self.tool_results_for_message(message.id) {
1859 write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1860 if tool_result.is_error {
1861 write!(markdown, " (Error)")?;
1862 }
1863
1864 writeln!(markdown, "**\n")?;
1865 writeln!(markdown, "{}", tool_result.content)?;
1866 }
1867 }
1868
1869 Ok(String::from_utf8_lossy(&markdown).to_string())
1870 }
1871
1872 pub fn keep_edits_in_range(
1873 &mut self,
1874 buffer: Entity<language::Buffer>,
1875 buffer_range: Range<language::Anchor>,
1876 cx: &mut Context<Self>,
1877 ) {
1878 self.action_log.update(cx, |action_log, cx| {
1879 action_log.keep_edits_in_range(buffer, buffer_range, cx)
1880 });
1881 }
1882
1883 pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1884 self.action_log
1885 .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1886 }
1887
1888 pub fn reject_edits_in_ranges(
1889 &mut self,
1890 buffer: Entity<language::Buffer>,
1891 buffer_ranges: Vec<Range<language::Anchor>>,
1892 cx: &mut Context<Self>,
1893 ) -> Task<Result<()>> {
1894 self.action_log.update(cx, |action_log, cx| {
1895 action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1896 })
1897 }
1898
1899 pub fn action_log(&self) -> &Entity<ActionLog> {
1900 &self.action_log
1901 }
1902
1903 pub fn project(&self) -> &Entity<Project> {
1904 &self.project
1905 }
1906
1907 pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1908 if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1909 return;
1910 }
1911
1912 let now = Instant::now();
1913 if let Some(last) = self.last_auto_capture_at {
1914 if now.duration_since(last).as_secs() < 10 {
1915 return;
1916 }
1917 }
1918
1919 self.last_auto_capture_at = Some(now);
1920
1921 let thread_id = self.id().clone();
1922 let github_login = self
1923 .project
1924 .read(cx)
1925 .user_store()
1926 .read(cx)
1927 .current_user()
1928 .map(|user| user.github_login.clone());
1929 let client = self.project.read(cx).client().clone();
1930 let serialize_task = self.serialize(cx);
1931
1932 cx.background_executor()
1933 .spawn(async move {
1934 if let Ok(serialized_thread) = serialize_task.await {
1935 if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1936 telemetry::event!(
1937 "Agent Thread Auto-Captured",
1938 thread_id = thread_id.to_string(),
1939 thread_data = thread_data,
1940 auto_capture_reason = "tracked_user",
1941 github_login = github_login
1942 );
1943
1944 client.telemetry().flush_events();
1945 }
1946 }
1947 })
1948 .detach();
1949 }
1950
1951 pub fn cumulative_token_usage(&self) -> TokenUsage {
1952 self.cumulative_token_usage
1953 }
1954
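    /// Returns the token usage recorded as of the message preceding `message_id`
    /// (zero for the first message), measured against the default model's context
    /// window.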
1955 pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
1956 let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
1957 return TotalTokenUsage::default();
1958 };
1959
1960 let max = model.model.max_token_count();
1961
1962 let index = self
1963 .messages
1964 .iter()
1965 .position(|msg| msg.id == message_id)
1966 .unwrap_or(0);
1967
1968 if index == 0 {
1969 return TotalTokenUsage { total: 0, max };
1970 }
1971
1972 let token_usage = &self
1973 .request_token_usage
1974 .get(index - 1)
1975 .cloned()
1976 .unwrap_or_default();
1977
1978 TotalTokenUsage {
1979 total: token_usage.total_tokens() as usize,
1980 max,
1981 }
1982 }
1983
1984 pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1985 let model_registry = LanguageModelRegistry::read_global(cx);
1986 let Some(model) = model_registry.default_model() else {
1987 return TotalTokenUsage::default();
1988 };
1989
1990 let max = model.model.max_token_count();
1991
1992 if let Some(exceeded_error) = &self.exceeded_window_error {
1993 if model.model.id() == exceeded_error.model_id {
1994 return TotalTokenUsage {
1995 total: exceeded_error.token_count,
1996 max,
1997 };
1998 }
1999 }
2000
2001 let total = self
2002 .token_usage_at_last_message()
2003 .unwrap_or_default()
2004 .total_tokens() as usize;
2005
2006 TotalTokenUsage { total, max }
2007 }
2008
2009 fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2010 self.request_token_usage
2011 .get(self.messages.len().saturating_sub(1))
2012 .or_else(|| self.request_token_usage.last())
2013 .cloned()
2014 }
2015
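    /// Keeps `request_token_usage` index-aligned with `messages`, then records the
    /// latest usage against the last message.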
2016 fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2017 let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2018 self.request_token_usage
2019 .resize(self.messages.len(), placeholder);
2020
2021 if let Some(last) = self.request_token_usage.last_mut() {
2022 *last = token_usage;
2023 }
2024 }
2025
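    /// Reports a user-denied tool use by recording an error result for the tool
    /// call and marking it as finished.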
2026 pub fn deny_tool_use(
2027 &mut self,
2028 tool_use_id: LanguageModelToolUseId,
2029 tool_name: Arc<str>,
2030 cx: &mut Context<Self>,
2031 ) {
2032 let err = Err(anyhow::anyhow!(
2033 "Permission to run tool action denied by user"
2034 ));
2035
2036 self.tool_use
2037 .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2038 self.tool_finished(tool_use_id.clone(), None, true, cx);
2039 }
2040}
2041
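/// Errors surfaced to the user while running a thread.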
2042#[derive(Debug, Clone, Error)]
2043pub enum ThreadError {
2044 #[error("Payment required")]
2045 PaymentRequired,
2046 #[error("Max monthly spend reached")]
2047 MaxMonthlySpendReached,
2048 #[error("Model request limit reached")]
2049 ModelRequestLimitReached { plan: Plan },
2050 #[error("Message {header}: {message}")]
2051 Message {
2052 header: SharedString,
2053 message: SharedString,
2054 },
2055}
2056
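/// Events emitted by a [`Thread`] as completions stream and tools run.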
2057#[derive(Debug, Clone)]
2058pub enum ThreadEvent {
2059 ShowError(ThreadError),
2060 UsageUpdated(RequestUsage),
2061 StreamedCompletion,
2062 StreamedAssistantText(MessageId, String),
2063 StreamedAssistantThinking(MessageId, String),
2064 Stopped(Result<StopReason, Arc<anyhow::Error>>),
2065 MessageAdded(MessageId),
2066 MessageEdited(MessageId),
2067 MessageDeleted(MessageId),
2068 SummaryGenerated,
2069 SummaryChanged,
2070 UsePendingTools {
2071 tool_uses: Vec<PendingToolUse>,
2072 },
2073 ToolFinished {
2074 #[allow(unused)]
2075 tool_use_id: LanguageModelToolUseId,
2076 /// The pending tool use that corresponds to this tool.
2077 pending_tool_use: Option<PendingToolUse>,
2078 },
2079 CheckpointChanged,
2080 ToolConfirmationNeeded,
2081}
2082
2083impl EventEmitter<ThreadEvent> for Thread {}
2084
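/// A completion request that is currently in flight, kept alive by holding its task.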
2085struct PendingCompletion {
2086 id: usize,
2087 _task: Task<()>,
2088}
2089
2090#[cfg(test)]
2091mod tests {
2092 use super::*;
2093 use crate::{ThreadStore, context_store::ContextStore, thread_store};
2094 use assistant_settings::AssistantSettings;
2095 use context_server::ContextServerSettings;
2096 use editor::EditorSettings;
2097 use gpui::TestAppContext;
2098 use project::{FakeFs, Project};
2099 use prompt_store::PromptBuilder;
2100 use serde_json::json;
2101 use settings::{Settings, SettingsStore};
2102 use std::sync::Arc;
2103 use theme::ThemeSettings;
2104 use util::path;
2105 use workspace::Workspace;
2106
2107 #[gpui::test]
2108 async fn test_message_with_context(cx: &mut TestAppContext) {
2109 init_test_settings(cx);
2110
2111 let project = create_test_project(
2112 cx,
2113 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2114 )
2115 .await;
2116
2117 let (_workspace, _thread_store, thread, context_store) =
2118 setup_test_environment(cx, project.clone()).await;
2119
2120 add_file_to_context(&project, &context_store, "test/code.rs", cx)
2121 .await
2122 .unwrap();
2123
2124 let context =
2125 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2126
2127 // Insert user message with context
2128 let message_id = thread.update(cx, |thread, cx| {
2129 thread.insert_user_message("Please explain this code", vec![context], None, cx)
2130 });
2131
2132 // Check content and context in message object
2133 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2134
        // File paths in the rendered context use platform-specific separators
2136 #[cfg(windows)]
2137 let path_part = r"test\code.rs";
2138 #[cfg(not(windows))]
2139 let path_part = "test/code.rs";
2140
2141 let expected_context = format!(
2142 r#"
2143<context>
2144The following items were attached by the user. You don't need to use other tools to read them.
2145
2146<files>
2147```rs {path_part}
2148fn main() {{
2149 println!("Hello, world!");
2150}}
2151```
2152</files>
2153</context>
2154"#
2155 );
2156
2157 assert_eq!(message.role, Role::User);
2158 assert_eq!(message.segments.len(), 1);
2159 assert_eq!(
2160 message.segments[0],
2161 MessageSegment::Text("Please explain this code".to_string())
2162 );
2163 assert_eq!(message.context, expected_context);
2164
2165 // Check message in request
2166 let request = thread.read_with(cx, |thread, cx| {
2167 thread.to_completion_request(RequestKind::Chat, cx)
2168 });
2169
2170 assert_eq!(request.messages.len(), 2);
2171 let expected_full_message = format!("{}Please explain this code", expected_context);
2172 assert_eq!(request.messages[1].string_contents(), expected_full_message);
2173 }
2174
2175 #[gpui::test]
2176 async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2177 init_test_settings(cx);
2178
2179 let project = create_test_project(
2180 cx,
2181 json!({
2182 "file1.rs": "fn function1() {}\n",
2183 "file2.rs": "fn function2() {}\n",
2184 "file3.rs": "fn function3() {}\n",
2185 }),
2186 )
2187 .await;
2188
2189 let (_, _thread_store, thread, context_store) =
2190 setup_test_environment(cx, project.clone()).await;
2191
2192 // Open files individually
2193 add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2194 .await
2195 .unwrap();
2196 add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2197 .await
2198 .unwrap();
2199 add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2200 .await
2201 .unwrap();
2202
2203 // Get the context objects
2204 let contexts = context_store.update(cx, |store, _| store.context().clone());
2205 assert_eq!(contexts.len(), 3);
2206
2207 // First message with context 1
2208 let message1_id = thread.update(cx, |thread, cx| {
2209 thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2210 });
2211
2212 // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2213 let message2_id = thread.update(cx, |thread, cx| {
2214 thread.insert_user_message(
2215 "Message 2",
2216 vec![contexts[0].clone(), contexts[1].clone()],
2217 None,
2218 cx,
2219 )
2220 });
2221
2222 // Third message with all three contexts (contexts 1 and 2 should be skipped)
2223 let message3_id = thread.update(cx, |thread, cx| {
2224 thread.insert_user_message(
2225 "Message 3",
2226 vec![
2227 contexts[0].clone(),
2228 contexts[1].clone(),
2229 contexts[2].clone(),
2230 ],
2231 None,
2232 cx,
2233 )
2234 });
2235
2236 // Check what contexts are included in each message
2237 let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2238 (
2239 thread.message(message1_id).unwrap().clone(),
2240 thread.message(message2_id).unwrap().clone(),
2241 thread.message(message3_id).unwrap().clone(),
2242 )
2243 });
2244
2245 // First message should include context 1
2246 assert!(message1.context.contains("file1.rs"));
2247
2248 // Second message should include only context 2 (not 1)
2249 assert!(!message2.context.contains("file1.rs"));
2250 assert!(message2.context.contains("file2.rs"));
2251
2252 // Third message should include only context 3 (not 1 or 2)
2253 assert!(!message3.context.contains("file1.rs"));
2254 assert!(!message3.context.contains("file2.rs"));
2255 assert!(message3.context.contains("file3.rs"));
2256
2257 // Check entire request to make sure all contexts are properly included
2258 let request = thread.read_with(cx, |thread, cx| {
2259 thread.to_completion_request(RequestKind::Chat, cx)
2260 });
2261
        // The request should contain all three user messages
2263 assert_eq!(request.messages.len(), 4);
2264
2265 // Check that the contexts are properly formatted in each message
2266 assert!(request.messages[1].string_contents().contains("file1.rs"));
2267 assert!(!request.messages[1].string_contents().contains("file2.rs"));
2268 assert!(!request.messages[1].string_contents().contains("file3.rs"));
2269
2270 assert!(!request.messages[2].string_contents().contains("file1.rs"));
2271 assert!(request.messages[2].string_contents().contains("file2.rs"));
2272 assert!(!request.messages[2].string_contents().contains("file3.rs"));
2273
2274 assert!(!request.messages[3].string_contents().contains("file1.rs"));
2275 assert!(!request.messages[3].string_contents().contains("file2.rs"));
2276 assert!(request.messages[3].string_contents().contains("file3.rs"));
2277 }
2278
2279 #[gpui::test]
2280 async fn test_message_without_files(cx: &mut TestAppContext) {
2281 init_test_settings(cx);
2282
2283 let project = create_test_project(
2284 cx,
2285 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2286 )
2287 .await;
2288
2289 let (_, _thread_store, thread, _context_store) =
2290 setup_test_environment(cx, project.clone()).await;
2291
2292 // Insert user message without any context (empty context vector)
2293 let message_id = thread.update(cx, |thread, cx| {
2294 thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2295 });
2296
2297 // Check content and context in message object
2298 let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2299
2300 // Context should be empty when no files are included
2301 assert_eq!(message.role, Role::User);
2302 assert_eq!(message.segments.len(), 1);
2303 assert_eq!(
2304 message.segments[0],
2305 MessageSegment::Text("What is the best way to learn Rust?".to_string())
2306 );
2307 assert_eq!(message.context, "");
2308
2309 // Check message in request
2310 let request = thread.read_with(cx, |thread, cx| {
2311 thread.to_completion_request(RequestKind::Chat, cx)
2312 });
2313
2314 assert_eq!(request.messages.len(), 2);
2315 assert_eq!(
2316 request.messages[1].string_contents(),
2317 "What is the best way to learn Rust?"
2318 );
2319
2320 // Add second message, also without context
2321 let message2_id = thread.update(cx, |thread, cx| {
2322 thread.insert_user_message("Are there any good books?", vec![], None, cx)
2323 });
2324
2325 let message2 =
2326 thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2327 assert_eq!(message2.context, "");
2328
2329 // Check that both messages appear in the request
2330 let request = thread.read_with(cx, |thread, cx| {
2331 thread.to_completion_request(RequestKind::Chat, cx)
2332 });
2333
2334 assert_eq!(request.messages.len(), 3);
2335 assert_eq!(
2336 request.messages[1].string_contents(),
2337 "What is the best way to learn Rust?"
2338 );
2339 assert_eq!(
2340 request.messages[2].string_contents(),
2341 "Are there any good books?"
2342 );
2343 }
2344
2345 #[gpui::test]
2346 async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2347 init_test_settings(cx);
2348
2349 let project = create_test_project(
2350 cx,
2351 json!({"code.rs": "fn main() {\n println!(\"Hello, world!\");\n}"}),
2352 )
2353 .await;
2354
2355 let (_workspace, _thread_store, thread, context_store) =
2356 setup_test_environment(cx, project.clone()).await;
2357
2358 // Open buffer and add it to context
2359 let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2360 .await
2361 .unwrap();
2362
2363 let context =
2364 context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2365
2366 // Insert user message with the buffer as context
2367 thread.update(cx, |thread, cx| {
2368 thread.insert_user_message("Explain this code", vec![context], None, cx)
2369 });
2370
2371 // Create a request and check that it doesn't have a stale buffer warning yet
2372 let initial_request = thread.read_with(cx, |thread, cx| {
2373 thread.to_completion_request(RequestKind::Chat, cx)
2374 });
2375
2376 // Make sure we don't have a stale file warning yet
2377 let has_stale_warning = initial_request.messages.iter().any(|msg| {
2378 msg.string_contents()
2379 .contains("These files changed since last read:")
2380 });
2381 assert!(
2382 !has_stale_warning,
2383 "Should not have stale buffer warning before buffer is modified"
2384 );
2385
2386 // Modify the buffer
2387 buffer.update(cx, |buffer, cx| {
            // Insert an extra line near the start; any edit marks the buffer as stale
2389 buffer.edit(
2390 [(1..1, "\n println!(\"Added a new line\");\n")],
2391 None,
2392 cx,
2393 );
2394 });
2395
2396 // Insert another user message without context
2397 thread.update(cx, |thread, cx| {
2398 thread.insert_user_message("What does the code do now?", vec![], None, cx)
2399 });
2400
2401 // Create a new request and check for the stale buffer warning
2402 let new_request = thread.read_with(cx, |thread, cx| {
2403 thread.to_completion_request(RequestKind::Chat, cx)
2404 });
2405
2406 // We should have a stale file warning as the last message
2407 let last_message = new_request
2408 .messages
2409 .last()
2410 .expect("Request should have messages");
2411
2412 // The last message should be the stale buffer notification
2413 assert_eq!(last_message.role, Role::User);
2414
2415 // Check the exact content of the message
2416 let expected_content = "These files changed since last read:\n- code.rs\n";
2417 assert_eq!(
2418 last_message.string_contents(),
2419 expected_content,
2420 "Last message should be exactly the stale buffer notification"
2421 );
2422 }
2423
2424 fn init_test_settings(cx: &mut TestAppContext) {
2425 cx.update(|cx| {
2426 let settings_store = SettingsStore::test(cx);
2427 cx.set_global(settings_store);
2428 language::init(cx);
2429 Project::init_settings(cx);
2430 AssistantSettings::register(cx);
2431 thread_store::init(cx);
2432 workspace::init_settings(cx);
2433 ThemeSettings::register(cx);
2434 ContextServerSettings::register(cx);
2435 EditorSettings::register(cx);
2436 });
2437 }
2438
2439 // Helper to create a test project with test files
2440 async fn create_test_project(
2441 cx: &mut TestAppContext,
2442 files: serde_json::Value,
2443 ) -> Entity<Project> {
2444 let fs = FakeFs::new(cx.executor());
2445 fs.insert_tree(path!("/test"), files).await;
2446 Project::test(fs, [path!("/test").as_ref()], cx).await
2447 }
2448
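    // Helper to set up a workspace, thread store, thread, and context store for tests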
2449 async fn setup_test_environment(
2450 cx: &mut TestAppContext,
2451 project: Entity<Project>,
2452 ) -> (
2453 Entity<Workspace>,
2454 Entity<ThreadStore>,
2455 Entity<Thread>,
2456 Entity<ContextStore>,
2457 ) {
2458 let (workspace, cx) =
2459 cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2460
2461 let thread_store = cx
2462 .update(|_, cx| {
2463 ThreadStore::load(
2464 project.clone(),
2465 cx.new(|_| ToolWorkingSet::default()),
2466 Arc::new(PromptBuilder::new(None).unwrap()),
2467 cx,
2468 )
2469 })
2470 .await;
2471
2472 let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2473 let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2474
2475 (workspace, thread_store, thread, context_store)
2476 }
2477
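    // Helper to open a file from the project and add its buffer to the context store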
2478 async fn add_file_to_context(
2479 project: &Entity<Project>,
2480 context_store: &Entity<ContextStore>,
2481 path: &str,
2482 cx: &mut TestAppContext,
2483 ) -> Result<Entity<language::Buffer>> {
2484 let buffer_path = project
2485 .read_with(cx, |project, cx| project.find_project_path(path, cx))
2486 .unwrap();
2487
2488 let buffer = project
2489 .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2490 .await
2491 .unwrap();
2492
2493 context_store
2494 .update(cx, |store, cx| {
2495 store.add_file_from_buffer(buffer.clone(), cx)
2496 })
2497 .await?;
2498
2499 Ok(buffer)
2500 }
2501}