thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  19    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  20    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  21    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  22    ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
  23};
  24use project::Project;
  25use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  26use prompt_store::PromptBuilder;
  27use proto::Plan;
  28use schemars::JsonSchema;
  29use serde::{Deserialize, Serialize};
  30use settings::Settings;
  31use thiserror::Error;
  32use util::{ResultExt as _, TryFutureExt as _, post_inc};
  33use uuid::Uuid;
  34
  35use crate::context::{AssistantContext, ContextId, format_context_as_string};
  36use crate::thread_store::{
  37    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  38    SerializedToolUse, SharedProjectContext,
  39};
  40use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  41
/// The kind of completion request being built for a thread.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// Presumably a regular conversational request; its handling is outside this view.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
  48
  49#[derive(
  50    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  51)]
  52pub struct ThreadId(Arc<str>);
  53
  54impl ThreadId {
  55    pub fn new() -> Self {
  56        Self(Uuid::new_v4().to_string().into())
  57    }
  58}
  59
  60impl std::fmt::Display for ThreadId {
  61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  62        write!(f, "{}", self.0)
  63    }
  64}
  65
  66impl From<&str> for ThreadId {
  67    fn from(value: &str) -> Self {
  68        Self(value.into())
  69    }
  70}
  71
  72#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  73pub struct MessageId(pub(crate) usize);
  74
  75impl MessageId {
  76    fn post_inc(&mut self) -> Self {
  77        Self(post_inc(&mut self.0))
  78    }
  79}
  80
  81/// A message in a [`Thread`].
  82#[derive(Debug, Clone)]
  83pub struct Message {
  84    pub id: MessageId,
  85    pub role: Role,
  86    pub segments: Vec<MessageSegment>,
  87    pub context: String,
  88}
  89
  90impl Message {
  91    /// Returns whether the message contains any meaningful text that should be displayed
  92    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  93    pub fn should_display_content(&self) -> bool {
  94        self.segments.iter().all(|segment| segment.should_display())
  95    }
  96
  97    pub fn push_thinking(&mut self, text: &str) {
  98        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  99            segment.push_str(text);
 100        } else {
 101            self.segments
 102                .push(MessageSegment::Thinking(text.to_string()));
 103        }
 104    }
 105
 106    pub fn push_text(&mut self, text: &str) {
 107        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 108            segment.push_str(text);
 109        } else {
 110            self.segments.push(MessageSegment::Text(text.to_string()));
 111        }
 112    }
 113
 114    pub fn to_string(&self) -> String {
 115        let mut result = String::new();
 116
 117        if !self.context.is_empty() {
 118            result.push_str(&self.context);
 119        }
 120
 121        for segment in &self.segments {
 122            match segment {
 123                MessageSegment::Text(text) => result.push_str(text),
 124                MessageSegment::Thinking(text) => {
 125                    result.push_str("<think>");
 126                    result.push_str(text);
 127                    result.push_str("</think>");
 128                }
 129            }
 130        }
 131
 132        result
 133    }
 134}
 135
 136#[derive(Debug, Clone, PartialEq, Eq)]
 137pub enum MessageSegment {
 138    Text(String),
 139    Thinking(String),
 140}
 141
 142impl MessageSegment {
 143    pub fn text_mut(&mut self) -> &mut String {
 144        match self {
 145            Self::Text(text) => text,
 146            Self::Thinking(text) => text,
 147        }
 148    }
 149
 150    pub fn should_display(&self) -> bool {
 151        // We add USING_TOOL_MARKER when making a request that includes tool uses
 152        // without non-whitespace text around them, and this can cause the model
 153        // to mimic the pattern, so we consider those segments not displayable.
 154        match self {
 155            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 156            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 157        }
 158    }
 159}
 160
/// Serializable snapshot of a project, stored with a thread as its
/// `initial_project_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per captured worktree.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Presumably paths of buffers with unsaved changes at capture time — TODO confirm
    /// against the code that builds the snapshot (not in view).
    pub unsaved_buffer_paths: Vec<String>,
    /// Timestamp associated with the snapshot.
    pub timestamp: DateTime<Utc>,
}
 167
/// Per-worktree part of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree, as a string.
    pub worktree_path: String,
    /// Git information for the worktree, when available.
    pub git_state: Option<GitState>,
}
 173
/// Git information recorded in a [`WorktreeSnapshot`]; every field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the remote, if known.
    pub remote_url: Option<String>,
    /// SHA of `HEAD`, if known.
    pub head_sha: Option<String>,
    /// Name of the current branch, if known.
    pub current_branch: Option<String>,
    /// Diff text, if captured (which diff is taken is decided outside this view).
    pub diff: Option<String>,
}
 181
/// A git checkpoint paired with the message it belongs to.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    // Message the checkpoint is keyed by in `Thread::checkpoints_by_message`.
    message_id: MessageId,
    // Handle passed to the git store when restoring, comparing or deleting.
    git_checkpoint: GitStoreCheckpoint,
}
 187
/// Feedback value that can be recorded for a thread or for a single message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 193
 194pub enum LastRestoreCheckpoint {
 195    Pending {
 196        message_id: MessageId,
 197    },
 198    Error {
 199        message_id: MessageId,
 200        error: String,
 201    },
 202}
 203
 204impl LastRestoreCheckpoint {
 205    pub fn message_id(&self) -> MessageId {
 206        match self {
 207            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 208            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 209        }
 210    }
 211}
 212
/// Progress of the thread's detailed summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary exists (the default).
    #[default]
    NotGenerated,
    /// A detailed summary is being generated.
    Generating {
        // presumably the last message covered by the summary — TODO confirm with the
        // generating code, which is outside this view
        message_id: MessageId,
    },
    /// A detailed summary is available.
    Generated {
        /// The summary text.
        text: SharedString,
        // presumably the last message covered by the summary — TODO confirm
        message_id: MessageId,
    },
}
 225
 226#[derive(Default)]
 227pub struct TotalTokenUsage {
 228    pub total: usize,
 229    pub max: usize,
 230}
 231
 232impl TotalTokenUsage {
 233    pub fn ratio(&self) -> TokenUsageRatio {
 234        #[cfg(debug_assertions)]
 235        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 236            .unwrap_or("0.8".to_string())
 237            .parse()
 238            .unwrap();
 239        #[cfg(not(debug_assertions))]
 240        let warning_threshold: f32 = 0.8;
 241
 242        if self.total >= self.max {
 243            TokenUsageRatio::Exceeded
 244        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 245            TokenUsageRatio::Warning
 246        } else {
 247            TokenUsageRatio::Normal
 248        }
 249    }
 250
 251    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 252        TotalTokenUsage {
 253            total: self.total + tokens,
 254            max: self.max,
 255        }
 256    }
 257}
 258
 259#[derive(Debug, Default, PartialEq, Eq)]
 260pub enum TokenUsageRatio {
 261    #[default]
 262    Normal,
 263    Warning,
 264    Exceeded,
 265}
 266
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    // Refreshed by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    // `None` until a summary exists; `set_summary` refuses to set it before that.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    // Messages in insertion order.
    messages: Vec<Message>,
    // Id handed to the next inserted message.
    next_message_id: MessageId,
    // Every context attached to the thread so far, keyed by id.
    context: BTreeMap<ContextId, AssistantContext>,
    // Ids of the contexts first attached with each user message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    // Checkpoints kept for messages (see `insert_checkpoint`).
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    // Non-empty while a completion is in flight (see `is_generating`).
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    // Set while a checkpoint restore is pending or after it failed; cleared on success.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint recorded with the latest user message, resolved by
    // `finalize_pending_checkpoint`.
    pending_checkpoint: Option<ThreadCheckpoint>,
    // Shared so it can be awaited from several places (e.g. `serialize`).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
}
 297
/// Details recorded when a message exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 305
 306impl Thread {
 307    pub fn new(
 308        project: Entity<Project>,
 309        tools: Entity<ToolWorkingSet>,
 310        prompt_builder: Arc<PromptBuilder>,
 311        system_prompt: SharedProjectContext,
 312        cx: &mut Context<Self>,
 313    ) -> Self {
 314        Self {
 315            id: ThreadId::new(),
 316            updated_at: Utc::now(),
 317            summary: None,
 318            pending_summary: Task::ready(None),
 319            detailed_summary_state: DetailedSummaryState::NotGenerated,
 320            messages: Vec::new(),
 321            next_message_id: MessageId(0),
 322            context: BTreeMap::default(),
 323            context_by_message: HashMap::default(),
 324            project_context: system_prompt,
 325            checkpoints_by_message: HashMap::default(),
 326            completion_count: 0,
 327            pending_completions: Vec::new(),
 328            project: project.clone(),
 329            prompt_builder,
 330            tools: tools.clone(),
 331            last_restore_checkpoint: None,
 332            pending_checkpoint: None,
 333            tool_use: ToolUseState::new(tools.clone()),
 334            action_log: cx.new(|_| ActionLog::new(project.clone())),
 335            initial_project_snapshot: {
 336                let project_snapshot = Self::project_snapshot(project, cx);
 337                cx.foreground_executor()
 338                    .spawn(async move { Some(project_snapshot.await) })
 339                    .shared()
 340            },
 341            request_token_usage: Vec::new(),
 342            cumulative_token_usage: TokenUsage::default(),
 343            exceeded_window_error: None,
 344            feedback: None,
 345            message_feedback: HashMap::default(),
 346            last_auto_capture_at: None,
 347        }
 348    }
 349
 350    pub fn deserialize(
 351        id: ThreadId,
 352        serialized: SerializedThread,
 353        project: Entity<Project>,
 354        tools: Entity<ToolWorkingSet>,
 355        prompt_builder: Arc<PromptBuilder>,
 356        project_context: SharedProjectContext,
 357        cx: &mut Context<Self>,
 358    ) -> Self {
 359        let next_message_id = MessageId(
 360            serialized
 361                .messages
 362                .last()
 363                .map(|message| message.id.0 + 1)
 364                .unwrap_or(0),
 365        );
 366        let tool_use =
 367            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 368
 369        Self {
 370            id,
 371            updated_at: serialized.updated_at,
 372            summary: Some(serialized.summary),
 373            pending_summary: Task::ready(None),
 374            detailed_summary_state: serialized.detailed_summary_state,
 375            messages: serialized
 376                .messages
 377                .into_iter()
 378                .map(|message| Message {
 379                    id: message.id,
 380                    role: message.role,
 381                    segments: message
 382                        .segments
 383                        .into_iter()
 384                        .map(|segment| match segment {
 385                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 386                            SerializedMessageSegment::Thinking { text } => {
 387                                MessageSegment::Thinking(text)
 388                            }
 389                        })
 390                        .collect(),
 391                    context: message.context,
 392                })
 393                .collect(),
 394            next_message_id,
 395            context: BTreeMap::default(),
 396            context_by_message: HashMap::default(),
 397            project_context,
 398            checkpoints_by_message: HashMap::default(),
 399            completion_count: 0,
 400            pending_completions: Vec::new(),
 401            last_restore_checkpoint: None,
 402            pending_checkpoint: None,
 403            project: project.clone(),
 404            prompt_builder,
 405            tools,
 406            tool_use,
 407            action_log: cx.new(|_| ActionLog::new(project)),
 408            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 409            request_token_usage: serialized.request_token_usage,
 410            cumulative_token_usage: serialized.cumulative_token_usage,
 411            exceeded_window_error: None,
 412            feedback: None,
 413            message_feedback: HashMap::default(),
 414            last_auto_capture_at: None,
 415        }
 416    }
 417
    /// Returns the identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 421
    /// Returns `true` when the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 425
    /// Returns the stored last-updated timestamp.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 429
    /// Sets the last-updated timestamp to the current UTC time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 433
    /// Returns a clone of the summary, or `None` if the thread has none yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
 437
    /// Returns a clone of the project context this thread was created with.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 441
    /// Summary used when the thread has none, and substituted for an empty summary
    /// passed to `set_summary`.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 443
 444    pub fn summary_or_default(&self) -> SharedString {
 445        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 446    }
 447
 448    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 449        let Some(current_summary) = &self.summary else {
 450            // Don't allow setting summary until generated
 451            return;
 452        };
 453
 454        let mut new_summary = new_summary.into();
 455
 456        if new_summary.is_empty() {
 457            new_summary = Self::DEFAULT_SUMMARY;
 458        }
 459
 460        if current_summary != &new_summary {
 461            self.summary = Some(new_summary);
 462            cx.emit(ThreadEvent::SummaryChanged);
 463        }
 464    }
 465
 466    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 467        self.latest_detailed_summary()
 468            .unwrap_or_else(|| self.text().into())
 469    }
 470
 471    fn latest_detailed_summary(&self) -> Option<SharedString> {
 472        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 473            Some(text.clone())
 474        } else {
 475            None
 476        }
 477    }
 478
 479    pub fn message(&self, id: MessageId) -> Option<&Message> {
 480        self.messages.iter().find(|message| message.id == id)
 481    }
 482
    /// Iterates over the thread's messages in stored order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }
 486
 487    pub fn is_generating(&self) -> bool {
 488        !self.pending_completions.is_empty() || !self.all_tools_finished()
 489    }
 490
    /// Returns the tool working set entity this thread was created with.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
 494
 495    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 496        self.tool_use
 497            .pending_tool_uses()
 498            .into_iter()
 499            .find(|tool_use| &tool_use.id == id)
 500    }
 501
 502    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 503        self.tool_use
 504            .pending_tool_uses()
 505            .into_iter()
 506            .filter(|tool_use| tool_use.status.needs_confirmation())
 507    }
 508
 509    pub fn has_pending_tool_uses(&self) -> bool {
 510        !self.tool_use.pending_tool_uses().is_empty()
 511    }
 512
    /// Returns a clone of the checkpoint stored for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 516
 517    pub fn restore_checkpoint(
 518        &mut self,
 519        checkpoint: ThreadCheckpoint,
 520        cx: &mut Context<Self>,
 521    ) -> Task<Result<()>> {
 522        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 523            message_id: checkpoint.message_id,
 524        });
 525        cx.emit(ThreadEvent::CheckpointChanged);
 526        cx.notify();
 527
 528        let git_store = self.project().read(cx).git_store().clone();
 529        let restore = git_store.update(cx, |git_store, cx| {
 530            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 531        });
 532
 533        cx.spawn(async move |this, cx| {
 534            let result = restore.await;
 535            this.update(cx, |this, cx| {
 536                if let Err(err) = result.as_ref() {
 537                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 538                        message_id: checkpoint.message_id,
 539                        error: err.to_string(),
 540                    });
 541                } else {
 542                    this.truncate(checkpoint.message_id, cx);
 543                    this.last_restore_checkpoint = None;
 544                }
 545                this.pending_checkpoint = None;
 546                cx.emit(ThreadEvent::CheckpointChanged);
 547                cx.notify();
 548            })?;
 549            result
 550        })
 551    }
 552
 553    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 554        let pending_checkpoint = if self.is_generating() {
 555            return;
 556        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 557            checkpoint
 558        } else {
 559            return;
 560        };
 561
 562        let git_store = self.project.read(cx).git_store().clone();
 563        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 564        cx.spawn(async move |this, cx| match final_checkpoint.await {
 565            Ok(final_checkpoint) => {
 566                let equal = git_store
 567                    .update(cx, |store, cx| {
 568                        store.compare_checkpoints(
 569                            pending_checkpoint.git_checkpoint.clone(),
 570                            final_checkpoint.clone(),
 571                            cx,
 572                        )
 573                    })?
 574                    .await
 575                    .unwrap_or(false);
 576
 577                if equal {
 578                    git_store
 579                        .update(cx, |store, cx| {
 580                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 581                        })?
 582                        .detach();
 583                } else {
 584                    this.update(cx, |this, cx| {
 585                        this.insert_checkpoint(pending_checkpoint, cx)
 586                    })?;
 587                }
 588
 589                git_store
 590                    .update(cx, |store, cx| {
 591                        store.delete_checkpoint(final_checkpoint, cx)
 592                    })?
 593                    .detach();
 594
 595                Ok(())
 596            }
 597            Err(_) => this.update(cx, |this, cx| {
 598                this.insert_checkpoint(pending_checkpoint, cx)
 599            }),
 600        })
 601        .detach();
 602    }
 603
 604    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 605        self.checkpoints_by_message
 606            .insert(checkpoint.message_id, checkpoint);
 607        cx.emit(ThreadEvent::CheckpointChanged);
 608        cx.notify();
 609    }
 610
    /// Returns the state of the last checkpoint restore while it is pending or after
    /// it failed; `None` otherwise.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 614
 615    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 616        let Some(message_ix) = self
 617            .messages
 618            .iter()
 619            .rposition(|message| message.id == message_id)
 620        else {
 621            return;
 622        };
 623        for deleted_message in self.messages.drain(message_ix..) {
 624            self.context_by_message.remove(&deleted_message.id);
 625            self.checkpoints_by_message.remove(&deleted_message.id);
 626        }
 627        cx.notify();
 628    }
 629
 630    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 631        self.context_by_message
 632            .get(&id)
 633            .into_iter()
 634            .flat_map(|context| {
 635                context
 636                    .iter()
 637                    .filter_map(|context_id| self.context.get(&context_id))
 638            })
 639    }
 640
 641    /// Returns whether all of the tool uses have finished running.
 642    pub fn all_tools_finished(&self) -> bool {
 643        // If the only pending tool uses left are the ones with errors, then
 644        // that means that we've finished running all of the pending tools.
 645        self.tool_use
 646            .pending_tool_uses()
 647            .iter()
 648            .all(|tool_use| tool_use.status.is_error())
 649    }
 650
    /// Returns the tool uses recorded for the given message (delegates to the
    /// thread's tool-use state).
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 654
    /// Returns the tool results recorded for the given message (delegates to the
    /// thread's tool-use state).
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }
 658
    /// Returns the result of the tool use with the given id, if one is recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 662
 663    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 664        Some(&self.tool_use.tool_result(id)?.content)
 665    }
 666
    /// Returns a clone of the card attached to the result of the given tool use, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 670
    /// Returns whether the thread's tool-use state reports tool results for the
    /// given message.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
 674
 675    /// Filter out contexts that have already been included in previous messages
 676    pub fn filter_new_context<'a>(
 677        &self,
 678        context: impl Iterator<Item = &'a AssistantContext>,
 679    ) -> impl Iterator<Item = &'a AssistantContext> {
 680        context.filter(|ctx| self.is_context_new(ctx))
 681    }
 682
 683    fn is_context_new(&self, context: &AssistantContext) -> bool {
 684        !self.context.contains_key(&context.id())
 685    }
 686
 687    pub fn insert_user_message(
 688        &mut self,
 689        text: impl Into<String>,
 690        context: Vec<AssistantContext>,
 691        git_checkpoint: Option<GitStoreCheckpoint>,
 692        cx: &mut Context<Self>,
 693    ) -> MessageId {
 694        let text = text.into();
 695
 696        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 697
 698        let new_context: Vec<_> = context
 699            .into_iter()
 700            .filter(|ctx| self.is_context_new(ctx))
 701            .collect();
 702
 703        if !new_context.is_empty() {
 704            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 705                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 706                    message.context = context_string;
 707                }
 708            }
 709
 710            self.action_log.update(cx, |log, cx| {
 711                // Track all buffers added as context
 712                for ctx in &new_context {
 713                    match ctx {
 714                        AssistantContext::File(file_ctx) => {
 715                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 716                        }
 717                        AssistantContext::Directory(dir_ctx) => {
 718                            for context_buffer in &dir_ctx.context_buffers {
 719                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 720                            }
 721                        }
 722                        AssistantContext::Symbol(symbol_ctx) => {
 723                            log.buffer_added_as_context(
 724                                symbol_ctx.context_symbol.buffer.clone(),
 725                                cx,
 726                            );
 727                        }
 728                        AssistantContext::Excerpt(excerpt_context) => {
 729                            log.buffer_added_as_context(
 730                                excerpt_context.context_buffer.buffer.clone(),
 731                                cx,
 732                            );
 733                        }
 734                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 735                    }
 736                }
 737            });
 738        }
 739
 740        let context_ids = new_context
 741            .iter()
 742            .map(|context| context.id())
 743            .collect::<Vec<_>>();
 744        self.context.extend(
 745            new_context
 746                .into_iter()
 747                .map(|context| (context.id(), context)),
 748        );
 749        self.context_by_message.insert(message_id, context_ids);
 750
 751        if let Some(git_checkpoint) = git_checkpoint {
 752            self.pending_checkpoint = Some(ThreadCheckpoint {
 753                message_id,
 754                git_checkpoint,
 755            });
 756        }
 757
 758        self.auto_capture_telemetry(cx);
 759
 760        message_id
 761    }
 762
 763    pub fn insert_message(
 764        &mut self,
 765        role: Role,
 766        segments: Vec<MessageSegment>,
 767        cx: &mut Context<Self>,
 768    ) -> MessageId {
 769        let id = self.next_message_id.post_inc();
 770        self.messages.push(Message {
 771            id,
 772            role,
 773            segments,
 774            context: String::new(),
 775        });
 776        self.touch_updated_at();
 777        cx.emit(ThreadEvent::MessageAdded(id));
 778        id
 779    }
 780
 781    pub fn edit_message(
 782        &mut self,
 783        id: MessageId,
 784        new_role: Role,
 785        new_segments: Vec<MessageSegment>,
 786        cx: &mut Context<Self>,
 787    ) -> bool {
 788        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 789            return false;
 790        };
 791        message.role = new_role;
 792        message.segments = new_segments;
 793        self.touch_updated_at();
 794        cx.emit(ThreadEvent::MessageEdited(id));
 795        true
 796    }
 797
 798    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 799        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 800            return false;
 801        };
 802        self.messages.remove(index);
 803        self.context_by_message.remove(&id);
 804        self.touch_updated_at();
 805        cx.emit(ThreadEvent::MessageDeleted(id));
 806        true
 807    }
 808
 809    /// Returns the representation of this [`Thread`] in a textual form.
 810    ///
 811    /// This is the representation we use when attaching a thread as context to another thread.
 812    pub fn text(&self) -> String {
 813        let mut text = String::new();
 814
 815        for message in &self.messages {
 816            text.push_str(match message.role {
 817                language_model::Role::User => "User:",
 818                language_model::Role::Assistant => "Assistant:",
 819                language_model::Role::System => "System:",
 820            });
 821            text.push('\n');
 822
 823            for segment in &message.segments {
 824                match segment {
 825                    MessageSegment::Text(content) => text.push_str(content),
 826                    MessageSegment::Thinking(content) => {
 827                        text.push_str(&format!("<think>{}</think>", content))
 828                    }
 829                }
 830            }
 831            text.push('\n');
 832        }
 833
 834        text
 835    }
 836
 837    /// Serializes this thread into a format for storage or telemetry.
 838    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 839        let initial_project_snapshot = self.initial_project_snapshot.clone();
 840        cx.spawn(async move |this, cx| {
 841            let initial_project_snapshot = initial_project_snapshot.await;
 842            this.read_with(cx, |this, cx| SerializedThread {
 843                version: SerializedThread::VERSION.to_string(),
 844                summary: this.summary_or_default(),
 845                updated_at: this.updated_at(),
 846                messages: this
 847                    .messages()
 848                    .map(|message| SerializedMessage {
 849                        id: message.id,
 850                        role: message.role,
 851                        segments: message
 852                            .segments
 853                            .iter()
 854                            .map(|segment| match segment {
 855                                MessageSegment::Text(text) => {
 856                                    SerializedMessageSegment::Text { text: text.clone() }
 857                                }
 858                                MessageSegment::Thinking(text) => {
 859                                    SerializedMessageSegment::Thinking { text: text.clone() }
 860                                }
 861                            })
 862                            .collect(),
 863                        tool_uses: this
 864                            .tool_uses_for_message(message.id, cx)
 865                            .into_iter()
 866                            .map(|tool_use| SerializedToolUse {
 867                                id: tool_use.id,
 868                                name: tool_use.name,
 869                                input: tool_use.input,
 870                            })
 871                            .collect(),
 872                        tool_results: this
 873                            .tool_results_for_message(message.id)
 874                            .into_iter()
 875                            .map(|tool_result| SerializedToolResult {
 876                                tool_use_id: tool_result.tool_use_id.clone(),
 877                                is_error: tool_result.is_error,
 878                                content: tool_result.content.clone(),
 879                            })
 880                            .collect(),
 881                        context: message.context.clone(),
 882                    })
 883                    .collect(),
 884                initial_project_snapshot,
 885                cumulative_token_usage: this.cumulative_token_usage,
 886                request_token_usage: this.request_token_usage.clone(),
 887                detailed_summary_state: this.detailed_summary_state.clone(),
 888                exceeded_window_error: this.exceeded_window_error.clone(),
 889            })
 890        })
 891    }
 892
 893    pub fn send_to_model(
 894        &mut self,
 895        model: Arc<dyn LanguageModel>,
 896        request_kind: RequestKind,
 897        cx: &mut Context<Self>,
 898    ) {
 899        let mut request = self.to_completion_request(request_kind, cx);
 900        if model.supports_tools() {
 901            request.tools = {
 902                let mut tools = Vec::new();
 903                tools.extend(
 904                    self.tools()
 905                        .read(cx)
 906                        .enabled_tools(cx)
 907                        .into_iter()
 908                        .filter_map(|tool| {
 909                            // Skip tools that cannot be supported
 910                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 911                            Some(LanguageModelRequestTool {
 912                                name: tool.name(),
 913                                description: tool.description(),
 914                                input_schema,
 915                            })
 916                        }),
 917                );
 918
 919                tools
 920            };
 921        }
 922
 923        self.stream_completion(request, model, cx);
 924    }
 925
 926    pub fn used_tools_since_last_user_message(&self) -> bool {
 927        for message in self.messages.iter().rev() {
 928            if self.tool_use.message_has_tool_results(message.id) {
 929                return true;
 930            } else if message.role == Role::User {
 931                return false;
 932            }
 933        }
 934
 935        false
 936    }
 937
 938    pub fn to_completion_request(
 939        &self,
 940        request_kind: RequestKind,
 941        cx: &App,
 942    ) -> LanguageModelRequest {
 943        let mut request = LanguageModelRequest {
 944            messages: vec![],
 945            tools: Vec::new(),
 946            stop: Vec::new(),
 947            temperature: None,
 948        };
 949
 950        if let Some(project_context) = self.project_context.borrow().as_ref() {
 951            if let Some(system_prompt) = self
 952                .prompt_builder
 953                .generate_assistant_system_prompt(project_context)
 954                .context("failed to generate assistant system prompt")
 955                .log_err()
 956            {
 957                request.messages.push(LanguageModelRequestMessage {
 958                    role: Role::System,
 959                    content: vec![MessageContent::Text(system_prompt)],
 960                    cache: true,
 961                });
 962            }
 963        } else {
 964            log::error!("project_context not set.")
 965        }
 966
 967        for message in &self.messages {
 968            let mut request_message = LanguageModelRequestMessage {
 969                role: message.role,
 970                content: Vec::new(),
 971                cache: false,
 972            };
 973
 974            match request_kind {
 975                RequestKind::Chat => {
 976                    self.tool_use
 977                        .attach_tool_results(message.id, &mut request_message);
 978                }
 979                RequestKind::Summarize => {
 980                    // We don't care about tool use during summarization.
 981                    if self.tool_use.message_has_tool_results(message.id) {
 982                        continue;
 983                    }
 984                }
 985            }
 986
 987            if !message.segments.is_empty() {
 988                request_message
 989                    .content
 990                    .push(MessageContent::Text(message.to_string()));
 991            }
 992
 993            match request_kind {
 994                RequestKind::Chat => {
 995                    self.tool_use
 996                        .attach_tool_uses(message.id, &mut request_message);
 997                }
 998                RequestKind::Summarize => {
 999                    // We don't care about tool use during summarization.
1000                }
1001            };
1002
1003            request.messages.push(request_message);
1004        }
1005
1006        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1007        if let Some(last) = request.messages.last_mut() {
1008            last.cache = true;
1009        }
1010
1011        self.attached_tracked_files_state(&mut request.messages, cx);
1012
1013        request
1014    }
1015
1016    fn attached_tracked_files_state(
1017        &self,
1018        messages: &mut Vec<LanguageModelRequestMessage>,
1019        cx: &App,
1020    ) {
1021        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1022
1023        let mut stale_message = String::new();
1024
1025        let action_log = self.action_log.read(cx);
1026
1027        for stale_file in action_log.stale_buffers(cx) {
1028            let Some(file) = stale_file.read(cx).file() else {
1029                continue;
1030            };
1031
1032            if stale_message.is_empty() {
1033                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1034            }
1035
1036            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1037        }
1038
1039        let mut content = Vec::with_capacity(2);
1040
1041        if !stale_message.is_empty() {
1042            content.push(stale_message.into());
1043        }
1044
1045        if action_log.has_edited_files_since_project_diagnostics_check() {
1046            content.push(
1047                "\n\nWhen you're done making changes, make sure to check project diagnostics \
1048                and fix all errors AND warnings you introduced! \
1049                DO NOT mention you're going to do this until you're done."
1050                    .into(),
1051            );
1052        }
1053
1054        if !content.is_empty() {
1055            let context_message = LanguageModelRequestMessage {
1056                role: Role::User,
1057                content,
1058                cache: false,
1059            };
1060
1061            messages.push(context_message);
1062        }
1063    }
1064
1065    pub fn stream_completion(
1066        &mut self,
1067        request: LanguageModelRequest,
1068        model: Arc<dyn LanguageModel>,
1069        cx: &mut Context<Self>,
1070    ) {
1071        let pending_completion_id = post_inc(&mut self.completion_count);
1072        let task = cx.spawn(async move |thread, cx| {
1073            let stream = model.stream_completion(request, &cx);
1074            let initial_token_usage =
1075                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1076            let stream_completion = async {
1077                let mut events = stream.await?;
1078                let mut stop_reason = StopReason::EndTurn;
1079                let mut current_token_usage = TokenUsage::default();
1080
1081                while let Some(event) = events.next().await {
1082                    let event = event?;
1083
1084                    thread.update(cx, |thread, cx| {
1085                        match event {
1086                            LanguageModelCompletionEvent::StartMessage { .. } => {
1087                                thread.insert_message(
1088                                    Role::Assistant,
1089                                    vec![MessageSegment::Text(String::new())],
1090                                    cx,
1091                                );
1092                            }
1093                            LanguageModelCompletionEvent::Stop(reason) => {
1094                                stop_reason = reason;
1095                            }
1096                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1097                                thread.update_token_usage_at_last_message(token_usage);
1098                                thread.cumulative_token_usage = thread.cumulative_token_usage
1099                                    + token_usage
1100                                    - current_token_usage;
1101                                current_token_usage = token_usage;
1102                            }
1103                            LanguageModelCompletionEvent::Text(chunk) => {
1104                                if let Some(last_message) = thread.messages.last_mut() {
1105                                    if last_message.role == Role::Assistant {
1106                                        last_message.push_text(&chunk);
1107                                        cx.emit(ThreadEvent::StreamedAssistantText(
1108                                            last_message.id,
1109                                            chunk,
1110                                        ));
1111                                    } else {
1112                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1113                                        // of a new Assistant response.
1114                                        //
1115                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1116                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1117                                        thread.insert_message(
1118                                            Role::Assistant,
1119                                            vec![MessageSegment::Text(chunk.to_string())],
1120                                            cx,
1121                                        );
1122                                    };
1123                                }
1124                            }
1125                            LanguageModelCompletionEvent::Thinking(chunk) => {
1126                                if let Some(last_message) = thread.messages.last_mut() {
1127                                    if last_message.role == Role::Assistant {
1128                                        last_message.push_thinking(&chunk);
1129                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1130                                            last_message.id,
1131                                            chunk,
1132                                        ));
1133                                    } else {
1134                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1135                                        // of a new Assistant response.
1136                                        //
1137                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1138                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1139                                        thread.insert_message(
1140                                            Role::Assistant,
1141                                            vec![MessageSegment::Thinking(chunk.to_string())],
1142                                            cx,
1143                                        );
1144                                    };
1145                                }
1146                            }
1147                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1148                                let last_assistant_message_id = thread
1149                                    .messages
1150                                    .iter_mut()
1151                                    .rfind(|message| message.role == Role::Assistant)
1152                                    .map(|message| message.id)
1153                                    .unwrap_or_else(|| {
1154                                        thread.insert_message(Role::Assistant, vec![], cx)
1155                                    });
1156
1157                                thread.tool_use.request_tool_use(
1158                                    last_assistant_message_id,
1159                                    tool_use,
1160                                    cx,
1161                                );
1162                            }
1163                        }
1164
1165                        thread.touch_updated_at();
1166                        cx.emit(ThreadEvent::StreamedCompletion);
1167                        cx.notify();
1168
1169                        thread.auto_capture_telemetry(cx);
1170                    })?;
1171
1172                    smol::future::yield_now().await;
1173                }
1174
1175                thread.update(cx, |thread, cx| {
1176                    thread
1177                        .pending_completions
1178                        .retain(|completion| completion.id != pending_completion_id);
1179
1180                    if thread.summary.is_none() && thread.messages.len() >= 2 {
1181                        thread.summarize(cx);
1182                    }
1183                })?;
1184
1185                anyhow::Ok(stop_reason)
1186            };
1187
1188            let result = stream_completion.await;
1189
1190            thread
1191                .update(cx, |thread, cx| {
1192                    thread.finalize_pending_checkpoint(cx);
1193                    match result.as_ref() {
1194                        Ok(stop_reason) => match stop_reason {
1195                            StopReason::ToolUse => {
1196                                let tool_uses = thread.use_pending_tools(cx);
1197                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1198                            }
1199                            StopReason::EndTurn => {}
1200                            StopReason::MaxTokens => {}
1201                        },
1202                        Err(error) => {
1203                            if error.is::<PaymentRequiredError>() {
1204                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1205                            } else if error.is::<MaxMonthlySpendReachedError>() {
1206                                cx.emit(ThreadEvent::ShowError(
1207                                    ThreadError::MaxMonthlySpendReached,
1208                                ));
1209                            } else if let Some(error) =
1210                                error.downcast_ref::<ModelRequestLimitReachedError>()
1211                            {
1212                                cx.emit(ThreadEvent::ShowError(
1213                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1214                                ));
1215                            } else if let Some(known_error) =
1216                                error.downcast_ref::<LanguageModelKnownError>()
1217                            {
1218                                match known_error {
1219                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1220                                        tokens,
1221                                    } => {
1222                                        thread.exceeded_window_error = Some(ExceededWindowError {
1223                                            model_id: model.id(),
1224                                            token_count: *tokens,
1225                                        });
1226                                        cx.notify();
1227                                    }
1228                                }
1229                            } else {
1230                                let error_message = error
1231                                    .chain()
1232                                    .map(|err| err.to_string())
1233                                    .collect::<Vec<_>>()
1234                                    .join("\n");
1235                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1236                                    header: "Error interacting with language model".into(),
1237                                    message: SharedString::from(error_message.clone()),
1238                                }));
1239                            }
1240
1241                            thread.cancel_last_completion(cx);
1242                        }
1243                    }
1244                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1245
1246                    thread.auto_capture_telemetry(cx);
1247
1248                    if let Ok(initial_usage) = initial_token_usage {
1249                        let usage = thread.cumulative_token_usage - initial_usage;
1250
1251                        telemetry::event!(
1252                            "Assistant Thread Completion",
1253                            thread_id = thread.id().to_string(),
1254                            model = model.telemetry_id(),
1255                            model_provider = model.provider_id().to_string(),
1256                            input_tokens = usage.input_tokens,
1257                            output_tokens = usage.output_tokens,
1258                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1259                            cache_read_input_tokens = usage.cache_read_input_tokens,
1260                        );
1261                    }
1262                })
1263                .ok();
1264        });
1265
1266        self.pending_completions.push(PendingCompletion {
1267            id: pending_completion_id,
1268            _task: task,
1269        });
1270    }
1271
1272    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1273        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1274            return;
1275        };
1276
1277        if !model.provider.is_authenticated(cx) {
1278            return;
1279        }
1280
1281        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1282        request.messages.push(LanguageModelRequestMessage {
1283            role: Role::User,
1284            content: vec![
1285                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1286                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1287                 If the conversation is about a specific subject, include it in the title. \
1288                 Be descriptive. DO NOT speak in the first person."
1289                    .into(),
1290            ],
1291            cache: false,
1292        });
1293
1294        self.pending_summary = cx.spawn(async move |this, cx| {
1295            async move {
1296                let stream = model.model.stream_completion_text(request, &cx);
1297                let mut messages = stream.await?;
1298
1299                let mut new_summary = String::new();
1300                while let Some(message) = messages.stream.next().await {
1301                    let text = message?;
1302                    let mut lines = text.lines();
1303                    new_summary.extend(lines.next());
1304
1305                    // Stop if the LLM generated multiple lines.
1306                    if lines.next().is_some() {
1307                        break;
1308                    }
1309                }
1310
1311                this.update(cx, |this, cx| {
1312                    if !new_summary.is_empty() {
1313                        this.summary = Some(new_summary.into());
1314                    }
1315
1316                    cx.emit(ThreadEvent::SummaryGenerated);
1317                })?;
1318
1319                anyhow::Ok(())
1320            }
1321            .log_err()
1322            .await
1323        });
1324    }
1325
1326    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1327        let last_message_id = self.messages.last().map(|message| message.id)?;
1328
1329        match &self.detailed_summary_state {
1330            DetailedSummaryState::Generating { message_id, .. }
1331            | DetailedSummaryState::Generated { message_id, .. }
1332                if *message_id == last_message_id =>
1333            {
1334                // Already up-to-date
1335                return None;
1336            }
1337            _ => {}
1338        }
1339
1340        let ConfiguredModel { model, provider } =
1341            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1342
1343        if !provider.is_authenticated(cx) {
1344            return None;
1345        }
1346
1347        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1348
1349        request.messages.push(LanguageModelRequestMessage {
1350            role: Role::User,
1351            content: vec![
1352                "Generate a detailed summary of this conversation. Include:\n\
1353                1. A brief overview of what was discussed\n\
1354                2. Key facts or information discovered\n\
1355                3. Outcomes or conclusions reached\n\
1356                4. Any action items or next steps if any\n\
1357                Format it in Markdown with headings and bullet points."
1358                    .into(),
1359            ],
1360            cache: false,
1361        });
1362
1363        let task = cx.spawn(async move |thread, cx| {
1364            let stream = model.stream_completion_text(request, &cx);
1365            let Some(mut messages) = stream.await.log_err() else {
1366                thread
1367                    .update(cx, |this, _cx| {
1368                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1369                    })
1370                    .log_err();
1371
1372                return;
1373            };
1374
1375            let mut new_detailed_summary = String::new();
1376
1377            while let Some(chunk) = messages.stream.next().await {
1378                if let Some(chunk) = chunk.log_err() {
1379                    new_detailed_summary.push_str(&chunk);
1380                }
1381            }
1382
1383            thread
1384                .update(cx, |this, _cx| {
1385                    this.detailed_summary_state = DetailedSummaryState::Generated {
1386                        text: new_detailed_summary.into(),
1387                        message_id: last_message_id,
1388                    };
1389                })
1390                .log_err();
1391        });
1392
1393        self.detailed_summary_state = DetailedSummaryState::Generating {
1394            message_id: last_message_id,
1395        };
1396
1397        Some(task)
1398    }
1399
1400    pub fn is_generating_detailed_summary(&self) -> bool {
1401        matches!(
1402            self.detailed_summary_state,
1403            DetailedSummaryState::Generating { .. }
1404        )
1405    }
1406
1407    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1408        self.auto_capture_telemetry(cx);
1409        let request = self.to_completion_request(RequestKind::Chat, cx);
1410        let messages = Arc::new(request.messages);
1411        let pending_tool_uses = self
1412            .tool_use
1413            .pending_tool_uses()
1414            .into_iter()
1415            .filter(|tool_use| tool_use.status.is_idle())
1416            .cloned()
1417            .collect::<Vec<_>>();
1418
1419        for tool_use in pending_tool_uses.iter() {
1420            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1421                if tool.needs_confirmation(&tool_use.input, cx)
1422                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1423                {
1424                    self.tool_use.confirm_tool_use(
1425                        tool_use.id.clone(),
1426                        tool_use.ui_text.clone(),
1427                        tool_use.input.clone(),
1428                        messages.clone(),
1429                        tool,
1430                    );
1431                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1432                } else {
1433                    self.run_tool(
1434                        tool_use.id.clone(),
1435                        tool_use.ui_text.clone(),
1436                        tool_use.input.clone(),
1437                        &messages,
1438                        tool,
1439                        cx,
1440                    );
1441                }
1442            }
1443        }
1444
1445        pending_tool_uses
1446    }
1447
1448    pub fn run_tool(
1449        &mut self,
1450        tool_use_id: LanguageModelToolUseId,
1451        ui_text: impl Into<SharedString>,
1452        input: serde_json::Value,
1453        messages: &[LanguageModelRequestMessage],
1454        tool: Arc<dyn Tool>,
1455        cx: &mut Context<Thread>,
1456    ) {
1457        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1458        self.tool_use
1459            .run_pending_tool(tool_use_id, ui_text.into(), task);
1460    }
1461
1462    fn spawn_tool_use(
1463        &mut self,
1464        tool_use_id: LanguageModelToolUseId,
1465        messages: &[LanguageModelRequestMessage],
1466        input: serde_json::Value,
1467        tool: Arc<dyn Tool>,
1468        cx: &mut Context<Thread>,
1469    ) -> Task<()> {
1470        let tool_name: Arc<str> = tool.name().into();
1471
1472        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1473            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1474        } else {
1475            tool.run(
1476                input,
1477                messages,
1478                self.project.clone(),
1479                self.action_log.clone(),
1480                cx,
1481            )
1482        };
1483
1484        // Store the card separately if it exists
1485        if let Some(card) = tool_result.card.clone() {
1486            self.tool_use
1487                .insert_tool_result_card(tool_use_id.clone(), card);
1488        }
1489
1490        cx.spawn({
1491            async move |thread: WeakEntity<Thread>, cx| {
1492                let output = tool_result.output.await;
1493
1494                thread
1495                    .update(cx, |thread, cx| {
1496                        let pending_tool_use = thread.tool_use.insert_tool_output(
1497                            tool_use_id.clone(),
1498                            tool_name,
1499                            output,
1500                            cx,
1501                        );
1502                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1503                    })
1504                    .ok();
1505            }
1506        })
1507    }
1508
1509    fn tool_finished(
1510        &mut self,
1511        tool_use_id: LanguageModelToolUseId,
1512        pending_tool_use: Option<PendingToolUse>,
1513        canceled: bool,
1514        cx: &mut Context<Self>,
1515    ) {
1516        if self.all_tools_finished() {
1517            let model_registry = LanguageModelRegistry::read_global(cx);
1518            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1519                self.attach_tool_results(cx);
1520                if !canceled {
1521                    self.send_to_model(model, RequestKind::Chat, cx);
1522                }
1523            }
1524        }
1525
1526        cx.emit(ThreadEvent::ToolFinished {
1527            tool_use_id,
1528            pending_tool_use,
1529        });
1530    }
1531
1532    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1533        // Insert a user message to contain the tool results.
1534        self.insert_user_message(
1535            // TODO: Sending up a user message without any content results in the model sending back
1536            // responses that also don't have any content. We currently don't handle this case well,
1537            // so for now we provide some text to keep the model on track.
1538            "Here are the tool results.",
1539            Vec::new(),
1540            None,
1541            cx,
1542        );
1543    }
1544
1545    /// Cancels the last pending completion, if there are any pending.
1546    ///
1547    /// Returns whether a completion was canceled.
1548    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1549        let canceled = if self.pending_completions.pop().is_some() {
1550            true
1551        } else {
1552            let mut canceled = false;
1553            for pending_tool_use in self.tool_use.cancel_pending() {
1554                canceled = true;
1555                self.tool_finished(
1556                    pending_tool_use.id.clone(),
1557                    Some(pending_tool_use),
1558                    true,
1559                    cx,
1560                );
1561            }
1562            canceled
1563        };
1564        self.finalize_pending_checkpoint(cx);
1565        canceled
1566    }
1567
1568    pub fn feedback(&self) -> Option<ThreadFeedback> {
1569        self.feedback
1570    }
1571
1572    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1573        self.message_feedback.get(&message_id).copied()
1574    }
1575
1576    pub fn report_message_feedback(
1577        &mut self,
1578        message_id: MessageId,
1579        feedback: ThreadFeedback,
1580        cx: &mut Context<Self>,
1581    ) -> Task<Result<()>> {
1582        if self.message_feedback.get(&message_id) == Some(&feedback) {
1583            return Task::ready(Ok(()));
1584        }
1585
1586        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1587        let serialized_thread = self.serialize(cx);
1588        let thread_id = self.id().clone();
1589        let client = self.project.read(cx).client();
1590
1591        let enabled_tool_names: Vec<String> = self
1592            .tools()
1593            .read(cx)
1594            .enabled_tools(cx)
1595            .iter()
1596            .map(|tool| tool.name().to_string())
1597            .collect();
1598
1599        self.message_feedback.insert(message_id, feedback);
1600
1601        cx.notify();
1602
1603        let message_content = self
1604            .message(message_id)
1605            .map(|msg| msg.to_string())
1606            .unwrap_or_default();
1607
1608        cx.background_spawn(async move {
1609            let final_project_snapshot = final_project_snapshot.await;
1610            let serialized_thread = serialized_thread.await?;
1611            let thread_data =
1612                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1613
1614            let rating = match feedback {
1615                ThreadFeedback::Positive => "positive",
1616                ThreadFeedback::Negative => "negative",
1617            };
1618            telemetry::event!(
1619                "Assistant Thread Rated",
1620                rating,
1621                thread_id,
1622                enabled_tool_names,
1623                message_id = message_id.0,
1624                message_content,
1625                thread_data,
1626                final_project_snapshot
1627            );
1628            client.telemetry().flush_events();
1629
1630            Ok(())
1631        })
1632    }
1633
1634    pub fn report_feedback(
1635        &mut self,
1636        feedback: ThreadFeedback,
1637        cx: &mut Context<Self>,
1638    ) -> Task<Result<()>> {
1639        let last_assistant_message_id = self
1640            .messages
1641            .iter()
1642            .rev()
1643            .find(|msg| msg.role == Role::Assistant)
1644            .map(|msg| msg.id);
1645
1646        if let Some(message_id) = last_assistant_message_id {
1647            self.report_message_feedback(message_id, feedback, cx)
1648        } else {
1649            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1650            let serialized_thread = self.serialize(cx);
1651            let thread_id = self.id().clone();
1652            let client = self.project.read(cx).client();
1653            self.feedback = Some(feedback);
1654            cx.notify();
1655
1656            cx.background_spawn(async move {
1657                let final_project_snapshot = final_project_snapshot.await;
1658                let serialized_thread = serialized_thread.await?;
1659                let thread_data = serde_json::to_value(serialized_thread)
1660                    .unwrap_or_else(|_| serde_json::Value::Null);
1661
1662                let rating = match feedback {
1663                    ThreadFeedback::Positive => "positive",
1664                    ThreadFeedback::Negative => "negative",
1665                };
1666                telemetry::event!(
1667                    "Assistant Thread Rated",
1668                    rating,
1669                    thread_id,
1670                    thread_data,
1671                    final_project_snapshot
1672                );
1673                client.telemetry().flush_events();
1674
1675                Ok(())
1676            })
1677        }
1678    }
1679
1680    /// Create a snapshot of the current project state including git information and unsaved buffers.
1681    fn project_snapshot(
1682        project: Entity<Project>,
1683        cx: &mut Context<Self>,
1684    ) -> Task<Arc<ProjectSnapshot>> {
1685        let git_store = project.read(cx).git_store().clone();
1686        let worktree_snapshots: Vec<_> = project
1687            .read(cx)
1688            .visible_worktrees(cx)
1689            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1690            .collect();
1691
1692        cx.spawn(async move |_, cx| {
1693            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1694
1695            let mut unsaved_buffers = Vec::new();
1696            cx.update(|app_cx| {
1697                let buffer_store = project.read(app_cx).buffer_store();
1698                for buffer_handle in buffer_store.read(app_cx).buffers() {
1699                    let buffer = buffer_handle.read(app_cx);
1700                    if buffer.is_dirty() {
1701                        if let Some(file) = buffer.file() {
1702                            let path = file.path().to_string_lossy().to_string();
1703                            unsaved_buffers.push(path);
1704                        }
1705                    }
1706                }
1707            })
1708            .ok();
1709
1710            Arc::new(ProjectSnapshot {
1711                worktree_snapshots,
1712                unsaved_buffer_paths: unsaved_buffers,
1713                timestamp: Utc::now(),
1714            })
1715        })
1716    }
1717
1718    fn worktree_snapshot(
1719        worktree: Entity<project::Worktree>,
1720        git_store: Entity<GitStore>,
1721        cx: &App,
1722    ) -> Task<WorktreeSnapshot> {
1723        cx.spawn(async move |cx| {
1724            // Get worktree path and snapshot
1725            let worktree_info = cx.update(|app_cx| {
1726                let worktree = worktree.read(app_cx);
1727                let path = worktree.abs_path().to_string_lossy().to_string();
1728                let snapshot = worktree.snapshot();
1729                (path, snapshot)
1730            });
1731
1732            let Ok((worktree_path, _snapshot)) = worktree_info else {
1733                return WorktreeSnapshot {
1734                    worktree_path: String::new(),
1735                    git_state: None,
1736                };
1737            };
1738
1739            let git_state = git_store
1740                .update(cx, |git_store, cx| {
1741                    git_store
1742                        .repositories()
1743                        .values()
1744                        .find(|repo| {
1745                            repo.read(cx)
1746                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1747                                .is_some()
1748                        })
1749                        .cloned()
1750                })
1751                .ok()
1752                .flatten()
1753                .map(|repo| {
1754                    repo.update(cx, |repo, _| {
1755                        let current_branch =
1756                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1757                        repo.send_job(None, |state, _| async move {
1758                            let RepositoryState::Local { backend, .. } = state else {
1759                                return GitState {
1760                                    remote_url: None,
1761                                    head_sha: None,
1762                                    current_branch,
1763                                    diff: None,
1764                                };
1765                            };
1766
1767                            let remote_url = backend.remote_url("origin");
1768                            let head_sha = backend.head_sha();
1769                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1770
1771                            GitState {
1772                                remote_url,
1773                                head_sha,
1774                                current_branch,
1775                                diff,
1776                            }
1777                        })
1778                    })
1779                });
1780
1781            let git_state = match git_state {
1782                Some(git_state) => match git_state.ok() {
1783                    Some(git_state) => git_state.await.ok(),
1784                    None => None,
1785                },
1786                None => None,
1787            };
1788
1789            WorktreeSnapshot {
1790                worktree_path,
1791                git_state,
1792            }
1793        })
1794    }
1795
1796    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1797        let mut markdown = Vec::new();
1798
1799        if let Some(summary) = self.summary() {
1800            writeln!(markdown, "# {summary}\n")?;
1801        };
1802
1803        for message in self.messages() {
1804            writeln!(
1805                markdown,
1806                "## {role}\n",
1807                role = match message.role {
1808                    Role::User => "User",
1809                    Role::Assistant => "Assistant",
1810                    Role::System => "System",
1811                }
1812            )?;
1813
1814            if !message.context.is_empty() {
1815                writeln!(markdown, "{}", message.context)?;
1816            }
1817
1818            for segment in &message.segments {
1819                match segment {
1820                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1821                    MessageSegment::Thinking(text) => {
1822                        writeln!(markdown, "<think>{}</think>\n", text)?
1823                    }
1824                }
1825            }
1826
1827            for tool_use in self.tool_uses_for_message(message.id, cx) {
1828                writeln!(
1829                    markdown,
1830                    "**Use Tool: {} ({})**",
1831                    tool_use.name, tool_use.id
1832                )?;
1833                writeln!(markdown, "```json")?;
1834                writeln!(
1835                    markdown,
1836                    "{}",
1837                    serde_json::to_string_pretty(&tool_use.input)?
1838                )?;
1839                writeln!(markdown, "```")?;
1840            }
1841
1842            for tool_result in self.tool_results_for_message(message.id) {
1843                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1844                if tool_result.is_error {
1845                    write!(markdown, " (Error)")?;
1846                }
1847
1848                writeln!(markdown, "**\n")?;
1849                writeln!(markdown, "{}", tool_result.content)?;
1850            }
1851        }
1852
1853        Ok(String::from_utf8_lossy(&markdown).to_string())
1854    }
1855
1856    pub fn keep_edits_in_range(
1857        &mut self,
1858        buffer: Entity<language::Buffer>,
1859        buffer_range: Range<language::Anchor>,
1860        cx: &mut Context<Self>,
1861    ) {
1862        self.action_log.update(cx, |action_log, cx| {
1863            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1864        });
1865    }
1866
1867    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1868        self.action_log
1869            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1870    }
1871
1872    pub fn reject_edits_in_ranges(
1873        &mut self,
1874        buffer: Entity<language::Buffer>,
1875        buffer_ranges: Vec<Range<language::Anchor>>,
1876        cx: &mut Context<Self>,
1877    ) -> Task<Result<()>> {
1878        self.action_log.update(cx, |action_log, cx| {
1879            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1880        })
1881    }
1882
    /// Returns the action log entity used by this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1886
    /// Returns the project entity this thread belongs to.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1890
1891    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1892        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1893            return;
1894        }
1895
1896        let now = Instant::now();
1897        if let Some(last) = self.last_auto_capture_at {
1898            if now.duration_since(last).as_secs() < 10 {
1899                return;
1900            }
1901        }
1902
1903        self.last_auto_capture_at = Some(now);
1904
1905        let thread_id = self.id().clone();
1906        let github_login = self
1907            .project
1908            .read(cx)
1909            .user_store()
1910            .read(cx)
1911            .current_user()
1912            .map(|user| user.github_login.clone());
1913        let client = self.project.read(cx).client().clone();
1914        let serialize_task = self.serialize(cx);
1915
1916        cx.background_executor()
1917            .spawn(async move {
1918                if let Ok(serialized_thread) = serialize_task.await {
1919                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1920                        telemetry::event!(
1921                            "Agent Thread Auto-Captured",
1922                            thread_id = thread_id.to_string(),
1923                            thread_data = thread_data,
1924                            auto_capture_reason = "tracked_user",
1925                            github_login = github_login
1926                        );
1927
1928                        client.telemetry().flush_events();
1929                    }
1930                }
1931            })
1932            .detach();
1933    }
1934
    /// Returns the thread's `cumulative_token_usage` counter.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
1938
1939    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
1940        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
1941            return TotalTokenUsage::default();
1942        };
1943
1944        let max = model.model.max_token_count();
1945
1946        let index = self
1947            .messages
1948            .iter()
1949            .position(|msg| msg.id == message_id)
1950            .unwrap_or(0);
1951
1952        if index == 0 {
1953            return TotalTokenUsage { total: 0, max };
1954        }
1955
1956        let token_usage = &self
1957            .request_token_usage
1958            .get(index - 1)
1959            .cloned()
1960            .unwrap_or_default();
1961
1962        TotalTokenUsage {
1963            total: token_usage.total_tokens() as usize,
1964            max,
1965        }
1966    }
1967
1968    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1969        let model_registry = LanguageModelRegistry::read_global(cx);
1970        let Some(model) = model_registry.default_model() else {
1971            return TotalTokenUsage::default();
1972        };
1973
1974        let max = model.model.max_token_count();
1975
1976        if let Some(exceeded_error) = &self.exceeded_window_error {
1977            if model.model.id() == exceeded_error.model_id {
1978                return TotalTokenUsage {
1979                    total: exceeded_error.token_count,
1980                    max,
1981                };
1982            }
1983        }
1984
1985        let total = self
1986            .token_usage_at_last_message()
1987            .unwrap_or_default()
1988            .total_tokens() as usize;
1989
1990        TotalTokenUsage { total, max }
1991    }
1992
1993    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
1994        self.request_token_usage
1995            .get(self.messages.len().saturating_sub(1))
1996            .or_else(|| self.request_token_usage.last())
1997            .cloned()
1998    }
1999
2000    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2001        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2002        self.request_token_usage
2003            .resize(self.messages.len(), placeholder);
2004
2005        if let Some(last) = self.request_token_usage.last_mut() {
2006            *last = token_usage;
2007        }
2008    }
2009
2010    pub fn deny_tool_use(
2011        &mut self,
2012        tool_use_id: LanguageModelToolUseId,
2013        tool_name: Arc<str>,
2014        cx: &mut Context<Self>,
2015    ) {
2016        let err = Err(anyhow::anyhow!(
2017            "Permission to run tool action denied by user"
2018        ));
2019
2020        self.tool_use
2021            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2022        self.tool_finished(tool_use_id.clone(), None, true, cx);
2023    }
2024}
2025
/// Errors carried by [`ThreadEvent::ShowError`].
///
/// The `#[error(...)]` attributes define the text produced when an error is displayed.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Displayed as "Payment required".
    #[error("Payment required")]
    PaymentRequired,
    /// Displayed as "Max monthly spend reached".
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// Displayed as "Model request limit reached".
    ///
    /// NOTE(review): `plan` is presumably the plan whose limit was hit — confirm at the
    /// construction site.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A free-form error, displayed as "Message {header}: {message}".
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2040
/// Events that a [`Thread`] can emit (see the `EventEmitter<ThreadEvent>` impl).
///
/// NOTE(review): the code that emits these events is outside this section; the
/// per-variant notes below are inferred from names and payload types — confirm against
/// the emitting code before relying on them.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries an error to be shown.
    ShowError(ThreadError),
    // Presumably signals that a completion produced streamed output.
    StreamedCompletion,
    // Presumably a streamed piece of assistant text for the given message.
    StreamedAssistantText(MessageId, String),
    // Presumably a streamed piece of assistant "thinking" text for the given message.
    StreamedAssistantThinking(MessageId, String),
    // Presumably signals that a completion stopped, with its stop reason or error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    // Presumably signals that the message with this id was added.
    MessageAdded(MessageId),
    // Presumably signals that the message with this id was edited.
    MessageEdited(MessageId),
    // Presumably signals that the message with this id was deleted.
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    // Presumably requests that the listed pending tool uses be run.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    // Presumably signals that a tool use finished.
    ToolFinished {
        // NOTE(review): marked unused; presumably kept for consumers/debugging — confirm
        // whether it is still needed.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}
2065
// Declares `ThreadEvent` as an event type that `Thread` emits.
impl EventEmitter<ThreadEvent> for Thread {}
2067
/// An in-flight completion; entries of this type are popped from
/// `pending_completions` by `cancel_last_completion`.
struct PendingCompletion {
    // NOTE(review): presumably a per-thread identifier for the completion — confirm
    // where it is assigned.
    id: usize,
    // Held only to keep the task alive alongside this entry.
    // NOTE(review): dropping the entry presumably cancels the task — confirm against
    // gpui's `Task` semantics.
    _task: Task<()>,
}
2072
2073#[cfg(test)]
2074mod tests {
2075    use super::*;
2076    use crate::{ThreadStore, context_store::ContextStore, thread_store};
2077    use assistant_settings::AssistantSettings;
2078    use context_server::ContextServerSettings;
2079    use editor::EditorSettings;
2080    use gpui::TestAppContext;
2081    use project::{FakeFs, Project};
2082    use prompt_store::PromptBuilder;
2083    use serde_json::json;
2084    use settings::{Settings, SettingsStore};
2085    use std::sync::Arc;
2086    use theme::ThemeSettings;
2087    use util::path;
2088    use workspace::Workspace;
2089
    /// Checks that a file attached as context is rendered into the message's `context`
    /// string and that the completion request contains that context followed by the
    /// user's text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // The expected text is compared byte-for-byte, including the leading and
        // trailing newlines of the raw string.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Two request messages are expected; the user message is the one at index 1.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2157
2158    #[gpui::test]
2159    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2160        init_test_settings(cx);
2161
2162        let project = create_test_project(
2163            cx,
2164            json!({
2165                "file1.rs": "fn function1() {}\n",
2166                "file2.rs": "fn function2() {}\n",
2167                "file3.rs": "fn function3() {}\n",
2168            }),
2169        )
2170        .await;
2171
2172        let (_, _thread_store, thread, context_store) =
2173            setup_test_environment(cx, project.clone()).await;
2174
2175        // Open files individually
2176        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2177            .await
2178            .unwrap();
2179        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2180            .await
2181            .unwrap();
2182        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2183            .await
2184            .unwrap();
2185
2186        // Get the context objects
2187        let contexts = context_store.update(cx, |store, _| store.context().clone());
2188        assert_eq!(contexts.len(), 3);
2189
2190        // First message with context 1
2191        let message1_id = thread.update(cx, |thread, cx| {
2192            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2193        });
2194
2195        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2196        let message2_id = thread.update(cx, |thread, cx| {
2197            thread.insert_user_message(
2198                "Message 2",
2199                vec![contexts[0].clone(), contexts[1].clone()],
2200                None,
2201                cx,
2202            )
2203        });
2204
2205        // Third message with all three contexts (contexts 1 and 2 should be skipped)
2206        let message3_id = thread.update(cx, |thread, cx| {
2207            thread.insert_user_message(
2208                "Message 3",
2209                vec![
2210                    contexts[0].clone(),
2211                    contexts[1].clone(),
2212                    contexts[2].clone(),
2213                ],
2214                None,
2215                cx,
2216            )
2217        });
2218
2219        // Check what contexts are included in each message
2220        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2221            (
2222                thread.message(message1_id).unwrap().clone(),
2223                thread.message(message2_id).unwrap().clone(),
2224                thread.message(message3_id).unwrap().clone(),
2225            )
2226        });
2227
2228        // First message should include context 1
2229        assert!(message1.context.contains("file1.rs"));
2230
2231        // Second message should include only context 2 (not 1)
2232        assert!(!message2.context.contains("file1.rs"));
2233        assert!(message2.context.contains("file2.rs"));
2234
2235        // Third message should include only context 3 (not 1 or 2)
2236        assert!(!message3.context.contains("file1.rs"));
2237        assert!(!message3.context.contains("file2.rs"));
2238        assert!(message3.context.contains("file3.rs"));
2239
2240        // Check entire request to make sure all contexts are properly included
2241        let request = thread.read_with(cx, |thread, cx| {
2242            thread.to_completion_request(RequestKind::Chat, cx)
2243        });
2244
2245        // The request should contain all 3 messages
2246        assert_eq!(request.messages.len(), 4);
2247
2248        // Check that the contexts are properly formatted in each message
2249        assert!(request.messages[1].string_contents().contains("file1.rs"));
2250        assert!(!request.messages[1].string_contents().contains("file2.rs"));
2251        assert!(!request.messages[1].string_contents().contains("file3.rs"));
2252
2253        assert!(!request.messages[2].string_contents().contains("file1.rs"));
2254        assert!(request.messages[2].string_contents().contains("file2.rs"));
2255        assert!(!request.messages[2].string_contents().contains("file3.rs"));
2256
2257        assert!(!request.messages[3].string_contents().contains("file1.rs"));
2258        assert!(!request.messages[3].string_contents().contains("file2.rs"));
2259        assert!(request.messages[3].string_contents().contains("file3.rs"));
2260    }
2261
    /// Checks that messages inserted without any context have an empty `context` string
    /// and appear in the completion request with exactly their text.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Two request messages are expected; the user message is the one at index 1.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2327
    /// Checks that, once a buffer attached as context has been edited, the completion
    /// request ends with a user message listing the changed file, and that no such
    /// message is present before the edit.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at position `1..1`.
            // NOTE(review): this looks like offset 1 (just after the first character), not
            // the end of line 1 as the earlier comment claimed — confirm. The exact position
            // should not matter here; the edit only needs to change the buffer.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2406
    /// Installs a test `SettingsStore` as a global and registers the settings and
    /// modules these tests rely on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store is installed first, before anything is registered.
            // NOTE(review): the registrations below presumably require it — confirm before
            // reordering.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }
2421
2422    // Helper to create a test project with test files
2423    async fn create_test_project(
2424        cx: &mut TestAppContext,
2425        files: serde_json::Value,
2426    ) -> Entity<Project> {
2427        let fs = FakeFs::new(cx.executor());
2428        fs.insert_tree(path!("/test"), files).await;
2429        Project::test(fs, [path!("/test").as_ref()], cx).await
2430    }
2431
    /// Builds the entities shared by the tests for `project`: a test workspace window,
    /// a loaded `ThreadStore`, a fresh thread created from that store, and an empty
    /// `ContextStore`.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        // `cx` is shadowed here by the context returned together with the new window;
        // everything below uses that window-bound context.
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await;

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }
2460
2461    async fn add_file_to_context(
2462        project: &Entity<Project>,
2463        context_store: &Entity<ContextStore>,
2464        path: &str,
2465        cx: &mut TestAppContext,
2466    ) -> Result<Entity<language::Buffer>> {
2467        let buffer_path = project
2468            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2469            .unwrap();
2470
2471        let buffer = project
2472            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2473            .await
2474            .unwrap();
2475
2476        context_store
2477            .update(cx, |store, cx| {
2478                store.add_file_from_buffer(buffer.clone(), cx)
2479            })
2480            .await?;
2481
2482        Ok(buffer)
2483    }
2484}