thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  19    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  20    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  21    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  22    ModelRequestLimitReachedError, PaymentRequiredError, RequestUsage, Role, StopReason,
  23    TokenUsage,
  24};
  25use project::Project;
  26use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  27use prompt_store::PromptBuilder;
  28use proto::Plan;
  29use schemars::JsonSchema;
  30use serde::{Deserialize, Serialize};
  31use settings::Settings;
  32use thiserror::Error;
  33use util::{ResultExt as _, TryFutureExt as _, post_inc};
  34use uuid::Uuid;
  35
  36use crate::context::{AssistantContext, ContextId, format_context_as_string};
  37use crate::thread_store::{
  38    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  39    SerializedToolUse, SharedProjectContext,
  40};
  41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  42
  43#[derive(Debug, Clone, Copy)]
  44pub enum RequestKind {
  45    Chat,
  46    /// Used when summarizing a thread.
  47    Summarize,
  48}
  49
  50#[derive(
  51    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  52)]
  53pub struct ThreadId(Arc<str>);
  54
  55impl ThreadId {
  56    pub fn new() -> Self {
  57        Self(Uuid::new_v4().to_string().into())
  58    }
  59}
  60
  61impl std::fmt::Display for ThreadId {
  62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  63        write!(f, "{}", self.0)
  64    }
  65}
  66
  67impl From<&str> for ThreadId {
  68    fn from(value: &str) -> Self {
  69        Self(value.into())
  70    }
  71}
  72
  73#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  74pub struct MessageId(pub(crate) usize);
  75
  76impl MessageId {
  77    fn post_inc(&mut self) -> Self {
  78        Self(post_inc(&mut self.0))
  79    }
  80}
  81
  82/// A message in a [`Thread`].
  83#[derive(Debug, Clone)]
  84pub struct Message {
  85    pub id: MessageId,
  86    pub role: Role,
  87    pub segments: Vec<MessageSegment>,
  88    pub context: String,
  89}
  90
  91impl Message {
  92    /// Returns whether the message contains any meaningful text that should be displayed
  93    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  94    pub fn should_display_content(&self) -> bool {
  95        self.segments.iter().all(|segment| segment.should_display())
  96    }
  97
  98    pub fn push_thinking(&mut self, text: &str) {
  99        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
 100            segment.push_str(text);
 101        } else {
 102            self.segments
 103                .push(MessageSegment::Thinking(text.to_string()));
 104        }
 105    }
 106
 107    pub fn push_text(&mut self, text: &str) {
 108        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 109            segment.push_str(text);
 110        } else {
 111            self.segments.push(MessageSegment::Text(text.to_string()));
 112        }
 113    }
 114
 115    pub fn to_string(&self) -> String {
 116        let mut result = String::new();
 117
 118        if !self.context.is_empty() {
 119            result.push_str(&self.context);
 120        }
 121
 122        for segment in &self.segments {
 123            match segment {
 124                MessageSegment::Text(text) => result.push_str(text),
 125                MessageSegment::Thinking(text) => {
 126                    result.push_str("<think>");
 127                    result.push_str(text);
 128                    result.push_str("</think>");
 129                }
 130            }
 131        }
 132
 133        result
 134    }
 135}
 136
 137#[derive(Debug, Clone, PartialEq, Eq)]
 138pub enum MessageSegment {
 139    Text(String),
 140    Thinking(String),
 141}
 142
 143impl MessageSegment {
 144    pub fn text_mut(&mut self) -> &mut String {
 145        match self {
 146            Self::Text(text) => text,
 147            Self::Thinking(text) => text,
 148        }
 149    }
 150
 151    pub fn should_display(&self) -> bool {
 152        // We add USING_TOOL_MARKER when making a request that includes tool uses
 153        // without non-whitespace text around them, and this can cause the model
 154        // to mimic the pattern, so we consider those segments not displayable.
 155        match self {
 156            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 157            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 158        }
 159    }
 160}
 161
 162#[derive(Debug, Clone, Serialize, Deserialize)]
 163pub struct ProjectSnapshot {
 164    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 165    pub unsaved_buffer_paths: Vec<String>,
 166    pub timestamp: DateTime<Utc>,
 167}
 168
 169#[derive(Debug, Clone, Serialize, Deserialize)]
 170pub struct WorktreeSnapshot {
 171    pub worktree_path: String,
 172    pub git_state: Option<GitState>,
 173}
 174
 175#[derive(Debug, Clone, Serialize, Deserialize)]
 176pub struct GitState {
 177    pub remote_url: Option<String>,
 178    pub head_sha: Option<String>,
 179    pub current_branch: Option<String>,
 180    pub diff: Option<String>,
 181}
 182
 183#[derive(Clone)]
 184pub struct ThreadCheckpoint {
 185    message_id: MessageId,
 186    git_checkpoint: GitStoreCheckpoint,
 187}
 188
 189#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 190pub enum ThreadFeedback {
 191    Positive,
 192    Negative,
 193}
 194
 195pub enum LastRestoreCheckpoint {
 196    Pending {
 197        message_id: MessageId,
 198    },
 199    Error {
 200        message_id: MessageId,
 201        error: String,
 202    },
 203}
 204
 205impl LastRestoreCheckpoint {
 206    pub fn message_id(&self) -> MessageId {
 207        match self {
 208            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 209            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 210        }
 211    }
 212}
 213
 214#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 215pub enum DetailedSummaryState {
 216    #[default]
 217    NotGenerated,
 218    Generating {
 219        message_id: MessageId,
 220    },
 221    Generated {
 222        text: SharedString,
 223        message_id: MessageId,
 224    },
 225}
 226
 227#[derive(Default)]
 228pub struct TotalTokenUsage {
 229    pub total: usize,
 230    pub max: usize,
 231}
 232
 233impl TotalTokenUsage {
 234    pub fn ratio(&self) -> TokenUsageRatio {
 235        #[cfg(debug_assertions)]
 236        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 237            .unwrap_or("0.8".to_string())
 238            .parse()
 239            .unwrap();
 240        #[cfg(not(debug_assertions))]
 241        let warning_threshold: f32 = 0.8;
 242
 243        if self.total >= self.max {
 244            TokenUsageRatio::Exceeded
 245        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 246            TokenUsageRatio::Warning
 247        } else {
 248            TokenUsageRatio::Normal
 249        }
 250    }
 251
 252    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 253        TotalTokenUsage {
 254            total: self.total + tokens,
 255            max: self.max,
 256        }
 257    }
 258}
 259
 260#[derive(Debug, Default, PartialEq, Eq)]
 261pub enum TokenUsageRatio {
 262    #[default]
 263    Normal,
 264    Warning,
 265    Exceeded,
 266}
 267
 268/// A thread of conversation with the LLM.
 269pub struct Thread {
 270    id: ThreadId,
 271    updated_at: DateTime<Utc>,
 272    summary: Option<SharedString>,
 273    pending_summary: Task<Option<()>>,
 274    detailed_summary_state: DetailedSummaryState,
 275    messages: Vec<Message>,
 276    next_message_id: MessageId,
 277    context: BTreeMap<ContextId, AssistantContext>,
 278    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 279    project_context: SharedProjectContext,
 280    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 281    completion_count: usize,
 282    pending_completions: Vec<PendingCompletion>,
 283    project: Entity<Project>,
 284    prompt_builder: Arc<PromptBuilder>,
 285    tools: Entity<ToolWorkingSet>,
 286    tool_use: ToolUseState,
 287    action_log: Entity<ActionLog>,
 288    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 289    pending_checkpoint: Option<ThreadCheckpoint>,
 290    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 291    request_token_usage: Vec<TokenUsage>,
 292    cumulative_token_usage: TokenUsage,
 293    exceeded_window_error: Option<ExceededWindowError>,
 294    feedback: Option<ThreadFeedback>,
 295    message_feedback: HashMap<MessageId, ThreadFeedback>,
 296    last_auto_capture_at: Option<Instant>,
 297}
 298
 299#[derive(Debug, Clone, Serialize, Deserialize)]
 300pub struct ExceededWindowError {
 301    /// Model used when last message exceeded context window
 302    model_id: LanguageModelId,
 303    /// Token count including last message
 304    token_count: usize,
 305}
 306
 307impl Thread {
 308    pub fn new(
 309        project: Entity<Project>,
 310        tools: Entity<ToolWorkingSet>,
 311        prompt_builder: Arc<PromptBuilder>,
 312        system_prompt: SharedProjectContext,
 313        cx: &mut Context<Self>,
 314    ) -> Self {
 315        Self {
 316            id: ThreadId::new(),
 317            updated_at: Utc::now(),
 318            summary: None,
 319            pending_summary: Task::ready(None),
 320            detailed_summary_state: DetailedSummaryState::NotGenerated,
 321            messages: Vec::new(),
 322            next_message_id: MessageId(0),
 323            context: BTreeMap::default(),
 324            context_by_message: HashMap::default(),
 325            project_context: system_prompt,
 326            checkpoints_by_message: HashMap::default(),
 327            completion_count: 0,
 328            pending_completions: Vec::new(),
 329            project: project.clone(),
 330            prompt_builder,
 331            tools: tools.clone(),
 332            last_restore_checkpoint: None,
 333            pending_checkpoint: None,
 334            tool_use: ToolUseState::new(tools.clone()),
 335            action_log: cx.new(|_| ActionLog::new(project.clone())),
 336            initial_project_snapshot: {
 337                let project_snapshot = Self::project_snapshot(project, cx);
 338                cx.foreground_executor()
 339                    .spawn(async move { Some(project_snapshot.await) })
 340                    .shared()
 341            },
 342            request_token_usage: Vec::new(),
 343            cumulative_token_usage: TokenUsage::default(),
 344            exceeded_window_error: None,
 345            feedback: None,
 346            message_feedback: HashMap::default(),
 347            last_auto_capture_at: None,
 348        }
 349    }
 350
 351    pub fn deserialize(
 352        id: ThreadId,
 353        serialized: SerializedThread,
 354        project: Entity<Project>,
 355        tools: Entity<ToolWorkingSet>,
 356        prompt_builder: Arc<PromptBuilder>,
 357        project_context: SharedProjectContext,
 358        cx: &mut Context<Self>,
 359    ) -> Self {
 360        let next_message_id = MessageId(
 361            serialized
 362                .messages
 363                .last()
 364                .map(|message| message.id.0 + 1)
 365                .unwrap_or(0),
 366        );
 367        let tool_use =
 368            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 369
 370        Self {
 371            id,
 372            updated_at: serialized.updated_at,
 373            summary: Some(serialized.summary),
 374            pending_summary: Task::ready(None),
 375            detailed_summary_state: serialized.detailed_summary_state,
 376            messages: serialized
 377                .messages
 378                .into_iter()
 379                .map(|message| Message {
 380                    id: message.id,
 381                    role: message.role,
 382                    segments: message
 383                        .segments
 384                        .into_iter()
 385                        .map(|segment| match segment {
 386                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 387                            SerializedMessageSegment::Thinking { text } => {
 388                                MessageSegment::Thinking(text)
 389                            }
 390                        })
 391                        .collect(),
 392                    context: message.context,
 393                })
 394                .collect(),
 395            next_message_id,
 396            context: BTreeMap::default(),
 397            context_by_message: HashMap::default(),
 398            project_context,
 399            checkpoints_by_message: HashMap::default(),
 400            completion_count: 0,
 401            pending_completions: Vec::new(),
 402            last_restore_checkpoint: None,
 403            pending_checkpoint: None,
 404            project: project.clone(),
 405            prompt_builder,
 406            tools,
 407            tool_use,
 408            action_log: cx.new(|_| ActionLog::new(project)),
 409            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 410            request_token_usage: serialized.request_token_usage,
 411            cumulative_token_usage: serialized.cumulative_token_usage,
 412            exceeded_window_error: None,
 413            feedback: None,
 414            message_feedback: HashMap::default(),
 415            last_auto_capture_at: None,
 416        }
 417    }
 418
 419    pub fn id(&self) -> &ThreadId {
 420        &self.id
 421    }
 422
 423    pub fn is_empty(&self) -> bool {
 424        self.messages.is_empty()
 425    }
 426
 427    pub fn updated_at(&self) -> DateTime<Utc> {
 428        self.updated_at
 429    }
 430
 431    pub fn touch_updated_at(&mut self) {
 432        self.updated_at = Utc::now();
 433    }
 434
 435    pub fn summary(&self) -> Option<SharedString> {
 436        self.summary.clone()
 437    }
 438
 439    pub fn project_context(&self) -> SharedProjectContext {
 440        self.project_context.clone()
 441    }
 442
 443    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 444
 445    pub fn summary_or_default(&self) -> SharedString {
 446        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 447    }
 448
 449    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 450        let Some(current_summary) = &self.summary else {
 451            // Don't allow setting summary until generated
 452            return;
 453        };
 454
 455        let mut new_summary = new_summary.into();
 456
 457        if new_summary.is_empty() {
 458            new_summary = Self::DEFAULT_SUMMARY;
 459        }
 460
 461        if current_summary != &new_summary {
 462            self.summary = Some(new_summary);
 463            cx.emit(ThreadEvent::SummaryChanged);
 464        }
 465    }
 466
 467    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 468        self.latest_detailed_summary()
 469            .unwrap_or_else(|| self.text().into())
 470    }
 471
 472    fn latest_detailed_summary(&self) -> Option<SharedString> {
 473        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 474            Some(text.clone())
 475        } else {
 476            None
 477        }
 478    }
 479
 480    pub fn message(&self, id: MessageId) -> Option<&Message> {
 481        self.messages.iter().find(|message| message.id == id)
 482    }
 483
 484    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 485        self.messages.iter()
 486    }
 487
 488    pub fn is_generating(&self) -> bool {
 489        !self.pending_completions.is_empty() || !self.all_tools_finished()
 490    }
 491
 492    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
 493        &self.tools
 494    }
 495
 496    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 497        self.tool_use
 498            .pending_tool_uses()
 499            .into_iter()
 500            .find(|tool_use| &tool_use.id == id)
 501    }
 502
 503    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 504        self.tool_use
 505            .pending_tool_uses()
 506            .into_iter()
 507            .filter(|tool_use| tool_use.status.needs_confirmation())
 508    }
 509
 510    pub fn has_pending_tool_uses(&self) -> bool {
 511        !self.tool_use.pending_tool_uses().is_empty()
 512    }
 513
 514    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 515        self.checkpoints_by_message.get(&id).cloned()
 516    }
 517
 518    pub fn restore_checkpoint(
 519        &mut self,
 520        checkpoint: ThreadCheckpoint,
 521        cx: &mut Context<Self>,
 522    ) -> Task<Result<()>> {
 523        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 524            message_id: checkpoint.message_id,
 525        });
 526        cx.emit(ThreadEvent::CheckpointChanged);
 527        cx.notify();
 528
 529        let git_store = self.project().read(cx).git_store().clone();
 530        let restore = git_store.update(cx, |git_store, cx| {
 531            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 532        });
 533
 534        cx.spawn(async move |this, cx| {
 535            let result = restore.await;
 536            this.update(cx, |this, cx| {
 537                if let Err(err) = result.as_ref() {
 538                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 539                        message_id: checkpoint.message_id,
 540                        error: err.to_string(),
 541                    });
 542                } else {
 543                    this.truncate(checkpoint.message_id, cx);
 544                    this.last_restore_checkpoint = None;
 545                }
 546                this.pending_checkpoint = None;
 547                cx.emit(ThreadEvent::CheckpointChanged);
 548                cx.notify();
 549            })?;
 550            result
 551        })
 552    }
 553
 554    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 555        let pending_checkpoint = if self.is_generating() {
 556            return;
 557        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 558            checkpoint
 559        } else {
 560            return;
 561        };
 562
 563        let git_store = self.project.read(cx).git_store().clone();
 564        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 565        cx.spawn(async move |this, cx| match final_checkpoint.await {
 566            Ok(final_checkpoint) => {
 567                let equal = git_store
 568                    .update(cx, |store, cx| {
 569                        store.compare_checkpoints(
 570                            pending_checkpoint.git_checkpoint.clone(),
 571                            final_checkpoint.clone(),
 572                            cx,
 573                        )
 574                    })?
 575                    .await
 576                    .unwrap_or(false);
 577
 578                if equal {
 579                    git_store
 580                        .update(cx, |store, cx| {
 581                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 582                        })?
 583                        .detach();
 584                } else {
 585                    this.update(cx, |this, cx| {
 586                        this.insert_checkpoint(pending_checkpoint, cx)
 587                    })?;
 588                }
 589
 590                git_store
 591                    .update(cx, |store, cx| {
 592                        store.delete_checkpoint(final_checkpoint, cx)
 593                    })?
 594                    .detach();
 595
 596                Ok(())
 597            }
 598            Err(_) => this.update(cx, |this, cx| {
 599                this.insert_checkpoint(pending_checkpoint, cx)
 600            }),
 601        })
 602        .detach();
 603    }
 604
 605    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 606        self.checkpoints_by_message
 607            .insert(checkpoint.message_id, checkpoint);
 608        cx.emit(ThreadEvent::CheckpointChanged);
 609        cx.notify();
 610    }
 611
 612    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 613        self.last_restore_checkpoint.as_ref()
 614    }
 615
 616    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 617        let Some(message_ix) = self
 618            .messages
 619            .iter()
 620            .rposition(|message| message.id == message_id)
 621        else {
 622            return;
 623        };
 624        for deleted_message in self.messages.drain(message_ix..) {
 625            self.context_by_message.remove(&deleted_message.id);
 626            self.checkpoints_by_message.remove(&deleted_message.id);
 627        }
 628        cx.notify();
 629    }
 630
 631    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 632        self.context_by_message
 633            .get(&id)
 634            .into_iter()
 635            .flat_map(|context| {
 636                context
 637                    .iter()
 638                    .filter_map(|context_id| self.context.get(&context_id))
 639            })
 640    }
 641
 642    /// Returns whether all of the tool uses have finished running.
 643    pub fn all_tools_finished(&self) -> bool {
 644        // If the only pending tool uses left are the ones with errors, then
 645        // that means that we've finished running all of the pending tools.
 646        self.tool_use
 647            .pending_tool_uses()
 648            .iter()
 649            .all(|tool_use| tool_use.status.is_error())
 650    }
 651
 652    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 653        self.tool_use.tool_uses_for_message(id, cx)
 654    }
 655
 656    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 657        self.tool_use.tool_results_for_message(id)
 658    }
 659
 660    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 661        self.tool_use.tool_result(id)
 662    }
 663
 664    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 665        Some(&self.tool_use.tool_result(id)?.content)
 666    }
 667
 668    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
 669        self.tool_use.tool_result_card(id).cloned()
 670    }
 671
 672    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 673        self.tool_use.message_has_tool_results(message_id)
 674    }
 675
 676    /// Filter out contexts that have already been included in previous messages
 677    pub fn filter_new_context<'a>(
 678        &self,
 679        context: impl Iterator<Item = &'a AssistantContext>,
 680    ) -> impl Iterator<Item = &'a AssistantContext> {
 681        context.filter(|ctx| self.is_context_new(ctx))
 682    }
 683
 684    fn is_context_new(&self, context: &AssistantContext) -> bool {
 685        !self.context.contains_key(&context.id())
 686    }
 687
 688    pub fn insert_user_message(
 689        &mut self,
 690        text: impl Into<String>,
 691        context: Vec<AssistantContext>,
 692        git_checkpoint: Option<GitStoreCheckpoint>,
 693        cx: &mut Context<Self>,
 694    ) -> MessageId {
 695        let text = text.into();
 696
 697        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 698
 699        let new_context: Vec<_> = context
 700            .into_iter()
 701            .filter(|ctx| self.is_context_new(ctx))
 702            .collect();
 703
 704        if !new_context.is_empty() {
 705            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 706                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 707                    message.context = context_string;
 708                }
 709            }
 710
 711            self.action_log.update(cx, |log, cx| {
 712                // Track all buffers added as context
 713                for ctx in &new_context {
 714                    match ctx {
 715                        AssistantContext::File(file_ctx) => {
 716                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 717                        }
 718                        AssistantContext::Directory(dir_ctx) => {
 719                            for context_buffer in &dir_ctx.context_buffers {
 720                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 721                            }
 722                        }
 723                        AssistantContext::Symbol(symbol_ctx) => {
 724                            log.buffer_added_as_context(
 725                                symbol_ctx.context_symbol.buffer.clone(),
 726                                cx,
 727                            );
 728                        }
 729                        AssistantContext::Excerpt(excerpt_context) => {
 730                            log.buffer_added_as_context(
 731                                excerpt_context.context_buffer.buffer.clone(),
 732                                cx,
 733                            );
 734                        }
 735                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 736                    }
 737                }
 738            });
 739        }
 740
 741        let context_ids = new_context
 742            .iter()
 743            .map(|context| context.id())
 744            .collect::<Vec<_>>();
 745        self.context.extend(
 746            new_context
 747                .into_iter()
 748                .map(|context| (context.id(), context)),
 749        );
 750        self.context_by_message.insert(message_id, context_ids);
 751
 752        if let Some(git_checkpoint) = git_checkpoint {
 753            self.pending_checkpoint = Some(ThreadCheckpoint {
 754                message_id,
 755                git_checkpoint,
 756            });
 757        }
 758
 759        self.auto_capture_telemetry(cx);
 760
 761        message_id
 762    }
 763
 764    pub fn insert_message(
 765        &mut self,
 766        role: Role,
 767        segments: Vec<MessageSegment>,
 768        cx: &mut Context<Self>,
 769    ) -> MessageId {
 770        let id = self.next_message_id.post_inc();
 771        self.messages.push(Message {
 772            id,
 773            role,
 774            segments,
 775            context: String::new(),
 776        });
 777        self.touch_updated_at();
 778        cx.emit(ThreadEvent::MessageAdded(id));
 779        id
 780    }
 781
 782    pub fn edit_message(
 783        &mut self,
 784        id: MessageId,
 785        new_role: Role,
 786        new_segments: Vec<MessageSegment>,
 787        cx: &mut Context<Self>,
 788    ) -> bool {
 789        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 790            return false;
 791        };
 792        message.role = new_role;
 793        message.segments = new_segments;
 794        self.touch_updated_at();
 795        cx.emit(ThreadEvent::MessageEdited(id));
 796        true
 797    }
 798
 799    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 800        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 801            return false;
 802        };
 803        self.messages.remove(index);
 804        self.context_by_message.remove(&id);
 805        self.touch_updated_at();
 806        cx.emit(ThreadEvent::MessageDeleted(id));
 807        true
 808    }
 809
 810    /// Returns the representation of this [`Thread`] in a textual form.
 811    ///
 812    /// This is the representation we use when attaching a thread as context to another thread.
 813    pub fn text(&self) -> String {
 814        let mut text = String::new();
 815
 816        for message in &self.messages {
 817            text.push_str(match message.role {
 818                language_model::Role::User => "User:",
 819                language_model::Role::Assistant => "Assistant:",
 820                language_model::Role::System => "System:",
 821            });
 822            text.push('\n');
 823
 824            for segment in &message.segments {
 825                match segment {
 826                    MessageSegment::Text(content) => text.push_str(content),
 827                    MessageSegment::Thinking(content) => {
 828                        text.push_str(&format!("<think>{}</think>", content))
 829                    }
 830                }
 831            }
 832            text.push('\n');
 833        }
 834
 835        text
 836    }
 837
 838    /// Serializes this thread into a format for storage or telemetry.
 839    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 840        let initial_project_snapshot = self.initial_project_snapshot.clone();
 841        cx.spawn(async move |this, cx| {
 842            let initial_project_snapshot = initial_project_snapshot.await;
 843            this.read_with(cx, |this, cx| SerializedThread {
 844                version: SerializedThread::VERSION.to_string(),
 845                summary: this.summary_or_default(),
 846                updated_at: this.updated_at(),
 847                messages: this
 848                    .messages()
 849                    .map(|message| SerializedMessage {
 850                        id: message.id,
 851                        role: message.role,
 852                        segments: message
 853                            .segments
 854                            .iter()
 855                            .map(|segment| match segment {
 856                                MessageSegment::Text(text) => {
 857                                    SerializedMessageSegment::Text { text: text.clone() }
 858                                }
 859                                MessageSegment::Thinking(text) => {
 860                                    SerializedMessageSegment::Thinking { text: text.clone() }
 861                                }
 862                            })
 863                            .collect(),
 864                        tool_uses: this
 865                            .tool_uses_for_message(message.id, cx)
 866                            .into_iter()
 867                            .map(|tool_use| SerializedToolUse {
 868                                id: tool_use.id,
 869                                name: tool_use.name,
 870                                input: tool_use.input,
 871                            })
 872                            .collect(),
 873                        tool_results: this
 874                            .tool_results_for_message(message.id)
 875                            .into_iter()
 876                            .map(|tool_result| SerializedToolResult {
 877                                tool_use_id: tool_result.tool_use_id.clone(),
 878                                is_error: tool_result.is_error,
 879                                content: tool_result.content.clone(),
 880                            })
 881                            .collect(),
 882                        context: message.context.clone(),
 883                    })
 884                    .collect(),
 885                initial_project_snapshot,
 886                cumulative_token_usage: this.cumulative_token_usage,
 887                request_token_usage: this.request_token_usage.clone(),
 888                detailed_summary_state: this.detailed_summary_state.clone(),
 889                exceeded_window_error: this.exceeded_window_error.clone(),
 890            })
 891        })
 892    }
 893
 894    pub fn send_to_model(
 895        &mut self,
 896        model: Arc<dyn LanguageModel>,
 897        request_kind: RequestKind,
 898        cx: &mut Context<Self>,
 899    ) {
 900        let mut request = self.to_completion_request(request_kind, cx);
 901        if model.supports_tools() {
 902            request.tools = {
 903                let mut tools = Vec::new();
 904                tools.extend(
 905                    self.tools()
 906                        .read(cx)
 907                        .enabled_tools(cx)
 908                        .into_iter()
 909                        .filter_map(|tool| {
 910                            // Skip tools that cannot be supported
 911                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 912                            Some(LanguageModelRequestTool {
 913                                name: tool.name(),
 914                                description: tool.description(),
 915                                input_schema,
 916                            })
 917                        }),
 918                );
 919
 920                tools
 921            };
 922        }
 923
 924        self.stream_completion(request, model, cx);
 925    }
 926
 927    pub fn used_tools_since_last_user_message(&self) -> bool {
 928        for message in self.messages.iter().rev() {
 929            if self.tool_use.message_has_tool_results(message.id) {
 930                return true;
 931            } else if message.role == Role::User {
 932                return false;
 933            }
 934        }
 935
 936        false
 937    }
 938
 939    pub fn to_completion_request(
 940        &self,
 941        request_kind: RequestKind,
 942        cx: &mut Context<Self>,
 943    ) -> LanguageModelRequest {
 944        let mut request = LanguageModelRequest {
 945            messages: vec![],
 946            tools: Vec::new(),
 947            stop: Vec::new(),
 948            temperature: None,
 949        };
 950
 951        if let Some(project_context) = self.project_context.borrow().as_ref() {
 952            match self
 953                .prompt_builder
 954                .generate_assistant_system_prompt(project_context)
 955            {
 956                Err(err) => {
 957                    let message = format!("{err:?}").into();
 958                    log::error!("{message}");
 959                    cx.emit(ThreadEvent::ShowError(ThreadError::Message {
 960                        header: "Error generating system prompt".into(),
 961                        message,
 962                    }));
 963                }
 964                Ok(system_prompt) => {
 965                    request.messages.push(LanguageModelRequestMessage {
 966                        role: Role::System,
 967                        content: vec![MessageContent::Text(system_prompt)],
 968                        cache: true,
 969                    });
 970                }
 971            }
 972        } else {
 973            let message = "Context for system prompt unexpectedly not ready.".into();
 974            log::error!("{message}");
 975            cx.emit(ThreadEvent::ShowError(ThreadError::Message {
 976                header: "Error generating system prompt".into(),
 977                message,
 978            }));
 979        }
 980
 981        for message in &self.messages {
 982            let mut request_message = LanguageModelRequestMessage {
 983                role: message.role,
 984                content: Vec::new(),
 985                cache: false,
 986            };
 987
 988            match request_kind {
 989                RequestKind::Chat => {
 990                    self.tool_use
 991                        .attach_tool_results(message.id, &mut request_message);
 992                }
 993                RequestKind::Summarize => {
 994                    // We don't care about tool use during summarization.
 995                    if self.tool_use.message_has_tool_results(message.id) {
 996                        continue;
 997                    }
 998                }
 999            }
1000
1001            if !message.segments.is_empty() {
1002                request_message
1003                    .content
1004                    .push(MessageContent::Text(message.to_string()));
1005            }
1006
1007            match request_kind {
1008                RequestKind::Chat => {
1009                    self.tool_use
1010                        .attach_tool_uses(message.id, &mut request_message);
1011                }
1012                RequestKind::Summarize => {
1013                    // We don't care about tool use during summarization.
1014                }
1015            };
1016
1017            request.messages.push(request_message);
1018        }
1019
1020        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1021        if let Some(last) = request.messages.last_mut() {
1022            last.cache = true;
1023        }
1024
1025        self.attached_tracked_files_state(&mut request.messages, cx);
1026
1027        request
1028    }
1029
1030    fn attached_tracked_files_state(
1031        &self,
1032        messages: &mut Vec<LanguageModelRequestMessage>,
1033        cx: &App,
1034    ) {
1035        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1036
1037        let mut stale_message = String::new();
1038
1039        let action_log = self.action_log.read(cx);
1040
1041        for stale_file in action_log.stale_buffers(cx) {
1042            let Some(file) = stale_file.read(cx).file() else {
1043                continue;
1044            };
1045
1046            if stale_message.is_empty() {
1047                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1048            }
1049
1050            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1051        }
1052
1053        let mut content = Vec::with_capacity(2);
1054
1055        if !stale_message.is_empty() {
1056            content.push(stale_message.into());
1057        }
1058
1059        if action_log.has_edited_files_since_project_diagnostics_check() {
1060            content.push(
1061                "\n\nWhen you're done making changes, make sure to check project diagnostics \
1062                and fix all errors AND warnings you introduced! \
1063                DO NOT mention you're going to do this until you're done."
1064                    .into(),
1065            );
1066        }
1067
1068        if !content.is_empty() {
1069            let context_message = LanguageModelRequestMessage {
1070                role: Role::User,
1071                content,
1072                cache: false,
1073            };
1074
1075            messages.push(context_message);
1076        }
1077    }
1078
1079    pub fn stream_completion(
1080        &mut self,
1081        request: LanguageModelRequest,
1082        model: Arc<dyn LanguageModel>,
1083        cx: &mut Context<Self>,
1084    ) {
1085        let pending_completion_id = post_inc(&mut self.completion_count);
1086        let task = cx.spawn(async move |thread, cx| {
1087            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
1088            let initial_token_usage =
1089                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1090            let stream_completion = async {
1091                let (mut events, usage) = stream_completion_future.await?;
1092                let mut stop_reason = StopReason::EndTurn;
1093                let mut current_token_usage = TokenUsage::default();
1094
1095                if let Some(usage) = usage {
1096                    thread
1097                        .update(cx, |_thread, cx| {
1098                            cx.emit(ThreadEvent::UsageUpdated(usage));
1099                        })
1100                        .ok();
1101                }
1102
1103                while let Some(event) = events.next().await {
1104                    let event = event?;
1105
1106                    thread.update(cx, |thread, cx| {
1107                        match event {
1108                            LanguageModelCompletionEvent::StartMessage { .. } => {
1109                                thread.insert_message(
1110                                    Role::Assistant,
1111                                    vec![MessageSegment::Text(String::new())],
1112                                    cx,
1113                                );
1114                            }
1115                            LanguageModelCompletionEvent::Stop(reason) => {
1116                                stop_reason = reason;
1117                            }
1118                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1119                                thread.update_token_usage_at_last_message(token_usage);
1120                                thread.cumulative_token_usage = thread.cumulative_token_usage
1121                                    + token_usage
1122                                    - current_token_usage;
1123                                current_token_usage = token_usage;
1124                            }
1125                            LanguageModelCompletionEvent::Text(chunk) => {
1126                                if let Some(last_message) = thread.messages.last_mut() {
1127                                    if last_message.role == Role::Assistant {
1128                                        last_message.push_text(&chunk);
1129                                        cx.emit(ThreadEvent::StreamedAssistantText(
1130                                            last_message.id,
1131                                            chunk,
1132                                        ));
1133                                    } else {
1134                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1135                                        // of a new Assistant response.
1136                                        //
1137                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1138                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1139                                        thread.insert_message(
1140                                            Role::Assistant,
1141                                            vec![MessageSegment::Text(chunk.to_string())],
1142                                            cx,
1143                                        );
1144                                    };
1145                                }
1146                            }
1147                            LanguageModelCompletionEvent::Thinking(chunk) => {
1148                                if let Some(last_message) = thread.messages.last_mut() {
1149                                    if last_message.role == Role::Assistant {
1150                                        last_message.push_thinking(&chunk);
1151                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1152                                            last_message.id,
1153                                            chunk,
1154                                        ));
1155                                    } else {
1156                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1157                                        // of a new Assistant response.
1158                                        //
1159                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1160                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1161                                        thread.insert_message(
1162                                            Role::Assistant,
1163                                            vec![MessageSegment::Thinking(chunk.to_string())],
1164                                            cx,
1165                                        );
1166                                    };
1167                                }
1168                            }
1169                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1170                                let last_assistant_message_id = thread
1171                                    .messages
1172                                    .iter_mut()
1173                                    .rfind(|message| message.role == Role::Assistant)
1174                                    .map(|message| message.id)
1175                                    .unwrap_or_else(|| {
1176                                        thread.insert_message(Role::Assistant, vec![], cx)
1177                                    });
1178
1179                                thread.tool_use.request_tool_use(
1180                                    last_assistant_message_id,
1181                                    tool_use,
1182                                    cx,
1183                                );
1184                            }
1185                        }
1186
1187                        thread.touch_updated_at();
1188                        cx.emit(ThreadEvent::StreamedCompletion);
1189                        cx.notify();
1190
1191                        thread.auto_capture_telemetry(cx);
1192                    })?;
1193
1194                    smol::future::yield_now().await;
1195                }
1196
1197                thread.update(cx, |thread, cx| {
1198                    thread
1199                        .pending_completions
1200                        .retain(|completion| completion.id != pending_completion_id);
1201
1202                    if thread.summary.is_none() && thread.messages.len() >= 2 {
1203                        thread.summarize(cx);
1204                    }
1205                })?;
1206
1207                anyhow::Ok(stop_reason)
1208            };
1209
1210            let result = stream_completion.await;
1211
1212            thread
1213                .update(cx, |thread, cx| {
1214                    thread.finalize_pending_checkpoint(cx);
1215                    match result.as_ref() {
1216                        Ok(stop_reason) => match stop_reason {
1217                            StopReason::ToolUse => {
1218                                let tool_uses = thread.use_pending_tools(cx);
1219                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1220                            }
1221                            StopReason::EndTurn => {}
1222                            StopReason::MaxTokens => {}
1223                        },
1224                        Err(error) => {
1225                            if error.is::<PaymentRequiredError>() {
1226                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1227                            } else if error.is::<MaxMonthlySpendReachedError>() {
1228                                cx.emit(ThreadEvent::ShowError(
1229                                    ThreadError::MaxMonthlySpendReached,
1230                                ));
1231                            } else if let Some(error) =
1232                                error.downcast_ref::<ModelRequestLimitReachedError>()
1233                            {
1234                                cx.emit(ThreadEvent::ShowError(
1235                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1236                                ));
1237                            } else if let Some(known_error) =
1238                                error.downcast_ref::<LanguageModelKnownError>()
1239                            {
1240                                match known_error {
1241                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1242                                        tokens,
1243                                    } => {
1244                                        thread.exceeded_window_error = Some(ExceededWindowError {
1245                                            model_id: model.id(),
1246                                            token_count: *tokens,
1247                                        });
1248                                        cx.notify();
1249                                    }
1250                                }
1251                            } else {
1252                                let error_message = error
1253                                    .chain()
1254                                    .map(|err| err.to_string())
1255                                    .collect::<Vec<_>>()
1256                                    .join("\n");
1257                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1258                                    header: "Error interacting with language model".into(),
1259                                    message: SharedString::from(error_message.clone()),
1260                                }));
1261                            }
1262
1263                            thread.cancel_last_completion(cx);
1264                        }
1265                    }
1266                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1267
1268                    thread.auto_capture_telemetry(cx);
1269
1270                    if let Ok(initial_usage) = initial_token_usage {
1271                        let usage = thread.cumulative_token_usage - initial_usage;
1272
1273                        telemetry::event!(
1274                            "Assistant Thread Completion",
1275                            thread_id = thread.id().to_string(),
1276                            model = model.telemetry_id(),
1277                            model_provider = model.provider_id().to_string(),
1278                            input_tokens = usage.input_tokens,
1279                            output_tokens = usage.output_tokens,
1280                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1281                            cache_read_input_tokens = usage.cache_read_input_tokens,
1282                        );
1283                    }
1284                })
1285                .ok();
1286        });
1287
1288        self.pending_completions.push(PendingCompletion {
1289            id: pending_completion_id,
1290            _task: task,
1291        });
1292    }
1293
1294    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1295        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1296            return;
1297        };
1298
1299        if !model.provider.is_authenticated(cx) {
1300            return;
1301        }
1302
1303        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1304        request.messages.push(LanguageModelRequestMessage {
1305            role: Role::User,
1306            content: vec![
1307                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1308                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1309                 If the conversation is about a specific subject, include it in the title. \
1310                 Be descriptive. DO NOT speak in the first person."
1311                    .into(),
1312            ],
1313            cache: false,
1314        });
1315
1316        self.pending_summary = cx.spawn(async move |this, cx| {
1317            async move {
1318                let stream = model.model.stream_completion_text_with_usage(request, &cx);
1319                let (mut messages, usage) = stream.await?;
1320
1321                if let Some(usage) = usage {
1322                    this.update(cx, |_thread, cx| {
1323                        cx.emit(ThreadEvent::UsageUpdated(usage));
1324                    })
1325                    .ok();
1326                }
1327
1328                let mut new_summary = String::new();
1329                while let Some(message) = messages.stream.next().await {
1330                    let text = message?;
1331                    let mut lines = text.lines();
1332                    new_summary.extend(lines.next());
1333
1334                    // Stop if the LLM generated multiple lines.
1335                    if lines.next().is_some() {
1336                        break;
1337                    }
1338                }
1339
1340                this.update(cx, |this, cx| {
1341                    if !new_summary.is_empty() {
1342                        this.summary = Some(new_summary.into());
1343                    }
1344
1345                    cx.emit(ThreadEvent::SummaryGenerated);
1346                })?;
1347
1348                anyhow::Ok(())
1349            }
1350            .log_err()
1351            .await
1352        });
1353    }
1354
1355    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1356        let last_message_id = self.messages.last().map(|message| message.id)?;
1357
1358        match &self.detailed_summary_state {
1359            DetailedSummaryState::Generating { message_id, .. }
1360            | DetailedSummaryState::Generated { message_id, .. }
1361                if *message_id == last_message_id =>
1362            {
1363                // Already up-to-date
1364                return None;
1365            }
1366            _ => {}
1367        }
1368
1369        let ConfiguredModel { model, provider } =
1370            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1371
1372        if !provider.is_authenticated(cx) {
1373            return None;
1374        }
1375
1376        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1377
1378        request.messages.push(LanguageModelRequestMessage {
1379            role: Role::User,
1380            content: vec![
1381                "Generate a detailed summary of this conversation. Include:\n\
1382                1. A brief overview of what was discussed\n\
1383                2. Key facts or information discovered\n\
1384                3. Outcomes or conclusions reached\n\
1385                4. Any action items or next steps if any\n\
1386                Format it in Markdown with headings and bullet points."
1387                    .into(),
1388            ],
1389            cache: false,
1390        });
1391
1392        let task = cx.spawn(async move |thread, cx| {
1393            let stream = model.stream_completion_text(request, &cx);
1394            let Some(mut messages) = stream.await.log_err() else {
1395                thread
1396                    .update(cx, |this, _cx| {
1397                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1398                    })
1399                    .log_err();
1400
1401                return;
1402            };
1403
1404            let mut new_detailed_summary = String::new();
1405
1406            while let Some(chunk) = messages.stream.next().await {
1407                if let Some(chunk) = chunk.log_err() {
1408                    new_detailed_summary.push_str(&chunk);
1409                }
1410            }
1411
1412            thread
1413                .update(cx, |this, _cx| {
1414                    this.detailed_summary_state = DetailedSummaryState::Generated {
1415                        text: new_detailed_summary.into(),
1416                        message_id: last_message_id,
1417                    };
1418                })
1419                .log_err();
1420        });
1421
1422        self.detailed_summary_state = DetailedSummaryState::Generating {
1423            message_id: last_message_id,
1424        };
1425
1426        Some(task)
1427    }
1428
1429    pub fn is_generating_detailed_summary(&self) -> bool {
1430        matches!(
1431            self.detailed_summary_state,
1432            DetailedSummaryState::Generating { .. }
1433        )
1434    }
1435
1436    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1437        self.auto_capture_telemetry(cx);
1438        let request = self.to_completion_request(RequestKind::Chat, cx);
1439        let messages = Arc::new(request.messages);
1440        let pending_tool_uses = self
1441            .tool_use
1442            .pending_tool_uses()
1443            .into_iter()
1444            .filter(|tool_use| tool_use.status.is_idle())
1445            .cloned()
1446            .collect::<Vec<_>>();
1447
1448        for tool_use in pending_tool_uses.iter() {
1449            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1450                if tool.needs_confirmation(&tool_use.input, cx)
1451                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1452                {
1453                    self.tool_use.confirm_tool_use(
1454                        tool_use.id.clone(),
1455                        tool_use.ui_text.clone(),
1456                        tool_use.input.clone(),
1457                        messages.clone(),
1458                        tool,
1459                    );
1460                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1461                } else {
1462                    self.run_tool(
1463                        tool_use.id.clone(),
1464                        tool_use.ui_text.clone(),
1465                        tool_use.input.clone(),
1466                        &messages,
1467                        tool,
1468                        cx,
1469                    );
1470                }
1471            }
1472        }
1473
1474        pending_tool_uses
1475    }
1476
1477    pub fn run_tool(
1478        &mut self,
1479        tool_use_id: LanguageModelToolUseId,
1480        ui_text: impl Into<SharedString>,
1481        input: serde_json::Value,
1482        messages: &[LanguageModelRequestMessage],
1483        tool: Arc<dyn Tool>,
1484        cx: &mut Context<Thread>,
1485    ) {
1486        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1487        self.tool_use
1488            .run_pending_tool(tool_use_id, ui_text.into(), task);
1489    }
1490
1491    fn spawn_tool_use(
1492        &mut self,
1493        tool_use_id: LanguageModelToolUseId,
1494        messages: &[LanguageModelRequestMessage],
1495        input: serde_json::Value,
1496        tool: Arc<dyn Tool>,
1497        cx: &mut Context<Thread>,
1498    ) -> Task<()> {
1499        let tool_name: Arc<str> = tool.name().into();
1500
1501        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1502            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1503        } else {
1504            tool.run(
1505                input,
1506                messages,
1507                self.project.clone(),
1508                self.action_log.clone(),
1509                cx,
1510            )
1511        };
1512
1513        // Store the card separately if it exists
1514        if let Some(card) = tool_result.card.clone() {
1515            self.tool_use
1516                .insert_tool_result_card(tool_use_id.clone(), card);
1517        }
1518
1519        cx.spawn({
1520            async move |thread: WeakEntity<Thread>, cx| {
1521                let output = tool_result.output.await;
1522
1523                thread
1524                    .update(cx, |thread, cx| {
1525                        let pending_tool_use = thread.tool_use.insert_tool_output(
1526                            tool_use_id.clone(),
1527                            tool_name,
1528                            output,
1529                            cx,
1530                        );
1531                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1532                    })
1533                    .ok();
1534            }
1535        })
1536    }
1537
1538    fn tool_finished(
1539        &mut self,
1540        tool_use_id: LanguageModelToolUseId,
1541        pending_tool_use: Option<PendingToolUse>,
1542        canceled: bool,
1543        cx: &mut Context<Self>,
1544    ) {
1545        if self.all_tools_finished() {
1546            let model_registry = LanguageModelRegistry::read_global(cx);
1547            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1548                self.attach_tool_results(cx);
1549                if !canceled {
1550                    self.send_to_model(model, RequestKind::Chat, cx);
1551                }
1552            }
1553        }
1554
1555        cx.emit(ThreadEvent::ToolFinished {
1556            tool_use_id,
1557            pending_tool_use,
1558        });
1559    }
1560
1561    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1562        // Insert a user message to contain the tool results.
1563        self.insert_user_message(
1564            // TODO: Sending up a user message without any content results in the model sending back
1565            // responses that also don't have any content. We currently don't handle this case well,
1566            // so for now we provide some text to keep the model on track.
1567            "Here are the tool results.",
1568            Vec::new(),
1569            None,
1570            cx,
1571        );
1572    }
1573
1574    /// Cancels the last pending completion, if there are any pending.
1575    ///
1576    /// Returns whether a completion was canceled.
1577    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1578        let canceled = if self.pending_completions.pop().is_some() {
1579            true
1580        } else {
1581            let mut canceled = false;
1582            for pending_tool_use in self.tool_use.cancel_pending() {
1583                canceled = true;
1584                self.tool_finished(
1585                    pending_tool_use.id.clone(),
1586                    Some(pending_tool_use),
1587                    true,
1588                    cx,
1589                );
1590            }
1591            canceled
1592        };
1593        self.finalize_pending_checkpoint(cx);
1594        canceled
1595    }
1596
    /// Returns the feedback stored for the thread as a whole, if any.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1600
    /// Returns the feedback stored for the message with the given id, if any.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1604
1605    pub fn report_message_feedback(
1606        &mut self,
1607        message_id: MessageId,
1608        feedback: ThreadFeedback,
1609        cx: &mut Context<Self>,
1610    ) -> Task<Result<()>> {
1611        if self.message_feedback.get(&message_id) == Some(&feedback) {
1612            return Task::ready(Ok(()));
1613        }
1614
1615        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1616        let serialized_thread = self.serialize(cx);
1617        let thread_id = self.id().clone();
1618        let client = self.project.read(cx).client();
1619
1620        let enabled_tool_names: Vec<String> = self
1621            .tools()
1622            .read(cx)
1623            .enabled_tools(cx)
1624            .iter()
1625            .map(|tool| tool.name().to_string())
1626            .collect();
1627
1628        self.message_feedback.insert(message_id, feedback);
1629
1630        cx.notify();
1631
1632        let message_content = self
1633            .message(message_id)
1634            .map(|msg| msg.to_string())
1635            .unwrap_or_default();
1636
1637        cx.background_spawn(async move {
1638            let final_project_snapshot = final_project_snapshot.await;
1639            let serialized_thread = serialized_thread.await?;
1640            let thread_data =
1641                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1642
1643            let rating = match feedback {
1644                ThreadFeedback::Positive => "positive",
1645                ThreadFeedback::Negative => "negative",
1646            };
1647            telemetry::event!(
1648                "Assistant Thread Rated",
1649                rating,
1650                thread_id,
1651                enabled_tool_names,
1652                message_id = message_id.0,
1653                message_content,
1654                thread_data,
1655                final_project_snapshot
1656            );
1657            client.telemetry().flush_events();
1658
1659            Ok(())
1660        })
1661    }
1662
1663    pub fn report_feedback(
1664        &mut self,
1665        feedback: ThreadFeedback,
1666        cx: &mut Context<Self>,
1667    ) -> Task<Result<()>> {
1668        let last_assistant_message_id = self
1669            .messages
1670            .iter()
1671            .rev()
1672            .find(|msg| msg.role == Role::Assistant)
1673            .map(|msg| msg.id);
1674
1675        if let Some(message_id) = last_assistant_message_id {
1676            self.report_message_feedback(message_id, feedback, cx)
1677        } else {
1678            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1679            let serialized_thread = self.serialize(cx);
1680            let thread_id = self.id().clone();
1681            let client = self.project.read(cx).client();
1682            self.feedback = Some(feedback);
1683            cx.notify();
1684
1685            cx.background_spawn(async move {
1686                let final_project_snapshot = final_project_snapshot.await;
1687                let serialized_thread = serialized_thread.await?;
1688                let thread_data = serde_json::to_value(serialized_thread)
1689                    .unwrap_or_else(|_| serde_json::Value::Null);
1690
1691                let rating = match feedback {
1692                    ThreadFeedback::Positive => "positive",
1693                    ThreadFeedback::Negative => "negative",
1694                };
1695                telemetry::event!(
1696                    "Assistant Thread Rated",
1697                    rating,
1698                    thread_id,
1699                    thread_data,
1700                    final_project_snapshot
1701                );
1702                client.telemetry().flush_events();
1703
1704                Ok(())
1705            })
1706        }
1707    }
1708
1709    /// Create a snapshot of the current project state including git information and unsaved buffers.
1710    fn project_snapshot(
1711        project: Entity<Project>,
1712        cx: &mut Context<Self>,
1713    ) -> Task<Arc<ProjectSnapshot>> {
1714        let git_store = project.read(cx).git_store().clone();
1715        let worktree_snapshots: Vec<_> = project
1716            .read(cx)
1717            .visible_worktrees(cx)
1718            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1719            .collect();
1720
1721        cx.spawn(async move |_, cx| {
1722            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1723
1724            let mut unsaved_buffers = Vec::new();
1725            cx.update(|app_cx| {
1726                let buffer_store = project.read(app_cx).buffer_store();
1727                for buffer_handle in buffer_store.read(app_cx).buffers() {
1728                    let buffer = buffer_handle.read(app_cx);
1729                    if buffer.is_dirty() {
1730                        if let Some(file) = buffer.file() {
1731                            let path = file.path().to_string_lossy().to_string();
1732                            unsaved_buffers.push(path);
1733                        }
1734                    }
1735                }
1736            })
1737            .ok();
1738
1739            Arc::new(ProjectSnapshot {
1740                worktree_snapshots,
1741                unsaved_buffer_paths: unsaved_buffers,
1742                timestamp: Utc::now(),
1743            })
1744        })
1745    }
1746
1747    fn worktree_snapshot(
1748        worktree: Entity<project::Worktree>,
1749        git_store: Entity<GitStore>,
1750        cx: &App,
1751    ) -> Task<WorktreeSnapshot> {
1752        cx.spawn(async move |cx| {
1753            // Get worktree path and snapshot
1754            let worktree_info = cx.update(|app_cx| {
1755                let worktree = worktree.read(app_cx);
1756                let path = worktree.abs_path().to_string_lossy().to_string();
1757                let snapshot = worktree.snapshot();
1758                (path, snapshot)
1759            });
1760
1761            let Ok((worktree_path, _snapshot)) = worktree_info else {
1762                return WorktreeSnapshot {
1763                    worktree_path: String::new(),
1764                    git_state: None,
1765                };
1766            };
1767
1768            let git_state = git_store
1769                .update(cx, |git_store, cx| {
1770                    git_store
1771                        .repositories()
1772                        .values()
1773                        .find(|repo| {
1774                            repo.read(cx)
1775                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1776                                .is_some()
1777                        })
1778                        .cloned()
1779                })
1780                .ok()
1781                .flatten()
1782                .map(|repo| {
1783                    repo.update(cx, |repo, _| {
1784                        let current_branch =
1785                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1786                        repo.send_job(None, |state, _| async move {
1787                            let RepositoryState::Local { backend, .. } = state else {
1788                                return GitState {
1789                                    remote_url: None,
1790                                    head_sha: None,
1791                                    current_branch,
1792                                    diff: None,
1793                                };
1794                            };
1795
1796                            let remote_url = backend.remote_url("origin");
1797                            let head_sha = backend.head_sha();
1798                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1799
1800                            GitState {
1801                                remote_url,
1802                                head_sha,
1803                                current_branch,
1804                                diff,
1805                            }
1806                        })
1807                    })
1808                });
1809
1810            let git_state = match git_state {
1811                Some(git_state) => match git_state.ok() {
1812                    Some(git_state) => git_state.await.ok(),
1813                    None => None,
1814                },
1815                None => None,
1816            };
1817
1818            WorktreeSnapshot {
1819                worktree_path,
1820                git_state,
1821            }
1822        })
1823    }
1824
1825    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1826        let mut markdown = Vec::new();
1827
1828        if let Some(summary) = self.summary() {
1829            writeln!(markdown, "# {summary}\n")?;
1830        };
1831
1832        for message in self.messages() {
1833            writeln!(
1834                markdown,
1835                "## {role}\n",
1836                role = match message.role {
1837                    Role::User => "User",
1838                    Role::Assistant => "Assistant",
1839                    Role::System => "System",
1840                }
1841            )?;
1842
1843            if !message.context.is_empty() {
1844                writeln!(markdown, "{}", message.context)?;
1845            }
1846
1847            for segment in &message.segments {
1848                match segment {
1849                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1850                    MessageSegment::Thinking(text) => {
1851                        writeln!(markdown, "<think>{}</think>\n", text)?
1852                    }
1853                }
1854            }
1855
1856            for tool_use in self.tool_uses_for_message(message.id, cx) {
1857                writeln!(
1858                    markdown,
1859                    "**Use Tool: {} ({})**",
1860                    tool_use.name, tool_use.id
1861                )?;
1862                writeln!(markdown, "```json")?;
1863                writeln!(
1864                    markdown,
1865                    "{}",
1866                    serde_json::to_string_pretty(&tool_use.input)?
1867                )?;
1868                writeln!(markdown, "```")?;
1869            }
1870
1871            for tool_result in self.tool_results_for_message(message.id) {
1872                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1873                if tool_result.is_error {
1874                    write!(markdown, " (Error)")?;
1875                }
1876
1877                writeln!(markdown, "**\n")?;
1878                writeln!(markdown, "{}", tool_result.content)?;
1879            }
1880        }
1881
1882        Ok(String::from_utf8_lossy(&markdown).to_string())
1883    }
1884
    /// Forwards to the action log's `keep_edits_in_range` for `buffer_range`
    /// in `buffer`.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
1895
    /// Forwards to the action log's `keep_all_edits`.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
1900
    /// Forwards to the action log's `reject_edits_in_ranges` for
    /// `buffer_ranges` in `buffer` and returns the task it produces.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
1911
    /// Returns the action log entity used by this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1915
    /// Returns the project entity this thread belongs to.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1919
1920    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1921        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1922            return;
1923        }
1924
1925        let now = Instant::now();
1926        if let Some(last) = self.last_auto_capture_at {
1927            if now.duration_since(last).as_secs() < 10 {
1928                return;
1929            }
1930        }
1931
1932        self.last_auto_capture_at = Some(now);
1933
1934        let thread_id = self.id().clone();
1935        let github_login = self
1936            .project
1937            .read(cx)
1938            .user_store()
1939            .read(cx)
1940            .current_user()
1941            .map(|user| user.github_login.clone());
1942        let client = self.project.read(cx).client().clone();
1943        let serialize_task = self.serialize(cx);
1944
1945        cx.background_executor()
1946            .spawn(async move {
1947                if let Ok(serialized_thread) = serialize_task.await {
1948                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1949                        telemetry::event!(
1950                            "Agent Thread Auto-Captured",
1951                            thread_id = thread_id.to_string(),
1952                            thread_data = thread_data,
1953                            auto_capture_reason = "tracked_user",
1954                            github_login = github_login
1955                        );
1956
1957                        client.telemetry().flush_events();
1958                    }
1959                }
1960            })
1961            .detach();
1962    }
1963
    /// Returns the thread's accumulated token usage.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
1967
1968    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
1969        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
1970            return TotalTokenUsage::default();
1971        };
1972
1973        let max = model.model.max_token_count();
1974
1975        let index = self
1976            .messages
1977            .iter()
1978            .position(|msg| msg.id == message_id)
1979            .unwrap_or(0);
1980
1981        if index == 0 {
1982            return TotalTokenUsage { total: 0, max };
1983        }
1984
1985        let token_usage = &self
1986            .request_token_usage
1987            .get(index - 1)
1988            .cloned()
1989            .unwrap_or_default();
1990
1991        TotalTokenUsage {
1992            total: token_usage.total_tokens() as usize,
1993            max,
1994        }
1995    }
1996
1997    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1998        let model_registry = LanguageModelRegistry::read_global(cx);
1999        let Some(model) = model_registry.default_model() else {
2000            return TotalTokenUsage::default();
2001        };
2002
2003        let max = model.model.max_token_count();
2004
2005        if let Some(exceeded_error) = &self.exceeded_window_error {
2006            if model.model.id() == exceeded_error.model_id {
2007                return TotalTokenUsage {
2008                    total: exceeded_error.token_count,
2009                    max,
2010                };
2011            }
2012        }
2013
2014        let total = self
2015            .token_usage_at_last_message()
2016            .unwrap_or_default()
2017            .total_tokens() as usize;
2018
2019        TotalTokenUsage { total, max }
2020    }
2021
2022    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2023        self.request_token_usage
2024            .get(self.messages.len().saturating_sub(1))
2025            .or_else(|| self.request_token_usage.last())
2026            .cloned()
2027    }
2028
2029    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2030        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2031        self.request_token_usage
2032            .resize(self.messages.len(), placeholder);
2033
2034        if let Some(last) = self.request_token_usage.last_mut() {
2035            *last = token_usage;
2036        }
2037    }
2038
2039    pub fn deny_tool_use(
2040        &mut self,
2041        tool_use_id: LanguageModelToolUseId,
2042        tool_name: Arc<str>,
2043        cx: &mut Context<Self>,
2044    ) {
2045        let err = Err(anyhow::anyhow!(
2046            "Permission to run tool action denied by user"
2047        ));
2048
2049        self.tool_use
2050            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2051        self.tool_finished(tool_use_id.clone(), None, true, cx);
2052    }
2053}
2054
/// Errors carried by `ThreadEvent::ShowError`.
///
/// The `#[error(...)]` attribute on each variant defines its display text.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Displayed as "Payment required".
    #[error("Payment required")]
    PaymentRequired,
    /// Displayed as "Max monthly spend reached".
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// Displayed as "Model request limit reached"; carries the associated `Plan`.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A free-form error made of a header and a message body.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2069
/// Events emitted by `Thread` (see the `EventEmitter<ThreadEvent>` impl).
///
/// NOTE(review): apart from `ToolFinished`, the emit sites are not visible in
/// this part of the file, so the variants are left to their names.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a `ThreadError`.
    ShowError(ThreadError),
    UsageUpdated(RequestUsage),
    StreamedCompletion,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    /// Carries either a stop reason or the error that ended the run.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// Emitted by `Thread::tool_finished`.
    ToolFinished {
        // NOTE(review): no reason is given for this `allow`; presumably the
        // field is not read by any current subscriber.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}
2095
// Declares `ThreadEvent` as the event type emitted by `Thread` via `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
2097
/// An in-flight completion tracked in `Thread::pending_completions`.
struct PendingCompletion {
    /// Identifier of this completion.
    id: usize,
    /// Task handle kept alongside the id. NOTE(review): presumably dropping
    /// this entry is what cancels the completion (`cancel_last_completion`
    /// only pops it) — confirm against gpui's `Task` semantics.
    _task: Task<()>,
}
2102
2103#[cfg(test)]
2104mod tests {
2105    use super::*;
2106    use crate::{ThreadStore, context_store::ContextStore, thread_store};
2107    use assistant_settings::AssistantSettings;
2108    use context_server::ContextServerSettings;
2109    use editor::EditorSettings;
2110    use gpui::TestAppContext;
2111    use project::{FakeFs, Project};
2112    use prompt_store::PromptBuilder;
2113    use serde_json::json;
2114    use settings::{Settings, SettingsStore};
2115    use std::sync::Arc;
2116    use theme::ThemeSettings;
2117    use util::path;
2118    use workspace::Workspace;
2119
    /// A user message inserted with a file context stores the rendered
    /// `<context>` block on the message, and the completion request places
    /// that block in front of the message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        // The store holds exactly the one file added above.
        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // Braces of the Rust snippet are doubled because this is a `format!` string.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Two messages: one leading message (presumably the system prompt)
        // followed by the user message, which is context + text.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2187
    /// Context that was already attached to an earlier message is not
    /// repeated: each message carries only the context items that are new
    /// relative to the messages before it.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain 4 messages: one leading message
        // (presumably the system prompt) plus the 3 user messages at indices 1..=3.
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }
2291
    /// Messages inserted without any context have an empty `context` string
    /// and appear in the completion request as their plain text.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Index 0 is the leading message (presumably the system prompt); the
        // user message follows at index 1.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2357
    /// After a buffer that was attached as context is edited, the next
    /// completion request ends with a user message listing the changed file.
    /// Before the edit, no such message is present.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at offset 1 (the empty range `1..1`); the exact
            // position is irrelevant, the buffer only has to differ from
            // what was read.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.update(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2436
2437    fn init_test_settings(cx: &mut TestAppContext) {
2438        cx.update(|cx| {
2439            let settings_store = SettingsStore::test(cx);
2440            cx.set_global(settings_store);
2441            language::init(cx);
2442            Project::init_settings(cx);
2443            AssistantSettings::register(cx);
2444            prompt_store::init(cx);
2445            thread_store::init(cx);
2446            workspace::init_settings(cx);
2447            ThemeSettings::register(cx);
2448            ContextServerSettings::register(cx);
2449            EditorSettings::register(cx);
2450        });
2451    }
2452
2453    // Helper to create a test project with test files
2454    async fn create_test_project(
2455        cx: &mut TestAppContext,
2456        files: serde_json::Value,
2457    ) -> Entity<Project> {
2458        let fs = FakeFs::new(cx.executor());
2459        fs.insert_tree(path!("/test"), files).await;
2460        Project::test(fs, [path!("/test").as_ref()], cx).await
2461    }
2462
2463    async fn setup_test_environment(
2464        cx: &mut TestAppContext,
2465        project: Entity<Project>,
2466    ) -> (
2467        Entity<Workspace>,
2468        Entity<ThreadStore>,
2469        Entity<Thread>,
2470        Entity<ContextStore>,
2471    ) {
2472        let (workspace, cx) =
2473            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2474
2475        let thread_store = cx
2476            .update(|_, cx| {
2477                ThreadStore::load(
2478                    project.clone(),
2479                    cx.new(|_| ToolWorkingSet::default()),
2480                    Arc::new(PromptBuilder::new(None).unwrap()),
2481                    cx,
2482                )
2483            })
2484            .await
2485            .unwrap();
2486
2487        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2488        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2489
2490        (workspace, thread_store, thread, context_store)
2491    }
2492
2493    async fn add_file_to_context(
2494        project: &Entity<Project>,
2495        context_store: &Entity<ContextStore>,
2496        path: &str,
2497        cx: &mut TestAppContext,
2498    ) -> Result<Entity<language::Buffer>> {
2499        let buffer_path = project
2500            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2501            .unwrap();
2502
2503        let buffer = project
2504            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2505            .await
2506            .unwrap();
2507
2508        context_store
2509            .update(cx, |store, cx| {
2510                store.add_file_from_buffer(buffer.clone(), cx)
2511            })
2512            .await?;
2513
2514        Ok(buffer)
2515    }
2516}