thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  19    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  20    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  21    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  22    ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
  23};
  24use project::Project;
  25use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  26use prompt_store::PromptBuilder;
  27use proto::Plan;
  28use schemars::JsonSchema;
  29use serde::{Deserialize, Serialize};
  30use settings::Settings;
  31use thiserror::Error;
  32use util::{ResultExt as _, TryFutureExt as _, post_inc};
  33use uuid::Uuid;
  34use zed_llm_client::UsageLimit;
  35
  36use crate::context::{AssistantContext, ContextId, format_context_as_string};
  37use crate::thread_store::{
  38    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  39    SerializedToolUse, SharedProjectContext,
  40};
  41use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  42
/// The kind of request being built for the language model (see
/// `Thread::send_to_model`, which takes one of these).
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A regular chat request.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
  49
/// Identifier of a [`Thread`], stored as a shared string.
///
/// Fresh ids come from [`ThreadId::new`]; existing ones can be wrapped via `From<&str>`.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
  54
  55impl ThreadId {
  56    pub fn new() -> Self {
  57        Self(Uuid::new_v4().to_string().into())
  58    }
  59}
  60
  61impl std::fmt::Display for ThreadId {
  62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  63        write!(f, "{}", self.0)
  64    }
  65}
  66
  67impl From<&str> for ThreadId {
  68    fn from(value: &str) -> Self {
  69        Self(value.into())
  70    }
  71}
  72
/// Identifier of a [`Message`] within a [`Thread`].
///
/// `Thread::insert_message` hands these out from the thread's `next_message_id` counter.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
  75
  76impl MessageId {
  77    fn post_inc(&mut self) -> Self {
  78        Self(post_inc(&mut self.0))
  79    }
  80}
  81
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Id assigned by the owning thread when the message was inserted.
    pub id: MessageId,
    /// Who authored the message.
    pub role: Role,
    /// Ordered text / thinking segments making up the message body.
    pub segments: Vec<MessageSegment>,
    /// Formatted context attached to the message. `Thread::insert_message` leaves it
    /// empty; `Thread::insert_user_message` fills it when new context was supplied.
    pub context: String,
}
  90
  91impl Message {
  92    /// Returns whether the message contains any meaningful text that should be displayed
  93    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  94    pub fn should_display_content(&self) -> bool {
  95        self.segments.iter().all(|segment| segment.should_display())
  96    }
  97
  98    pub fn push_thinking(&mut self, text: &str) {
  99        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
 100            segment.push_str(text);
 101        } else {
 102            self.segments
 103                .push(MessageSegment::Thinking(text.to_string()));
 104        }
 105    }
 106
 107    pub fn push_text(&mut self, text: &str) {
 108        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 109            segment.push_str(text);
 110        } else {
 111            self.segments.push(MessageSegment::Text(text.to_string()));
 112        }
 113    }
 114
 115    pub fn to_string(&self) -> String {
 116        let mut result = String::new();
 117
 118        if !self.context.is_empty() {
 119            result.push_str(&self.context);
 120        }
 121
 122        for segment in &self.segments {
 123            match segment {
 124                MessageSegment::Text(text) => result.push_str(text),
 125                MessageSegment::Thinking(text) => {
 126                    result.push_str("<think>");
 127                    result.push_str(text);
 128                    result.push_str("</think>");
 129                }
 130            }
 131        }
 132
 133        result
 134    }
 135}
 136
/// One contiguous piece of a [`Message`] body.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Regular message text.
    Text(String),
    /// Thinking text; rendered wrapped in `<think>…</think>` by `Message::to_string`.
    Thinking(String),
}
 142
 143impl MessageSegment {
 144    pub fn text_mut(&mut self) -> &mut String {
 145        match self {
 146            Self::Text(text) => text,
 147            Self::Thinking(text) => text,
 148        }
 149    }
 150
 151    pub fn should_display(&self) -> bool {
 152        // We add USING_TOOL_MARKER when making a request that includes tool uses
 153        // without non-whitespace text around them, and this can cause the model
 154        // to mimic the pattern, so we consider those segments not displayable.
 155        match self {
 156            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 157            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 158        }
 159    }
 160}
 161
/// Serializable snapshot of project state; a [`Thread`] keeps one as its
/// `initial_project_snapshot` and includes it when serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree captured in the snapshot.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes (presumably at snapshot time — confirm
    /// against `Thread::project_snapshot`).
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
 168
/// Per-worktree part of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree.
    pub worktree_path: String,
    /// Git information for the worktree, when available.
    pub git_state: Option<GitState>,
}
 174
/// Git details recorded for a worktree in a [`WorktreeSnapshot`]. Every field is
/// optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the repository's remote.
    pub remote_url: Option<String>,
    /// SHA of `HEAD`.
    pub head_sha: Option<String>,
    /// Name of the checked-out branch.
    pub current_branch: Option<String>,
    /// Textual diff (which diff is captured is decided where the snapshot is built,
    /// outside this section).
    pub diff: Option<String>,
}
 182
/// A git-store checkpoint associated with a specific message of a [`Thread`].
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// Message the checkpoint belongs to; restoring truncates the thread at it.
    message_id: MessageId,
    /// Checkpoint handle passed to the project's git store.
    git_checkpoint: GitStoreCheckpoint,
}
 188
/// Feedback value recorded for a thread or for an individual message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 194
/// Status of the most recent checkpoint restore started by
/// `Thread::restore_checkpoint`. A successful restore clears the stored value.
pub enum LastRestoreCheckpoint {
    /// The restore has been requested and has not completed yet.
    Pending {
        message_id: MessageId,
    },
    /// The restore failed; `error` is the error rendered as a string.
    Error {
        message_id: MessageId,
        error: String,
    },
}
 204
 205impl LastRestoreCheckpoint {
 206    pub fn message_id(&self) -> MessageId {
 207        match self {
 208            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 209            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 210        }
 211    }
 212}
 213
/// Progress of a thread's detailed summary.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary exists (the default).
    #[default]
    NotGenerated,
    /// A detailed summary is being generated; `message_id` identifies a message the
    /// generation is tied to (presumably the last message covered — confirm).
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary is available as `text`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 226
 227#[derive(Default)]
 228pub struct TotalTokenUsage {
 229    pub total: usize,
 230    pub max: usize,
 231}
 232
 233impl TotalTokenUsage {
 234    pub fn ratio(&self) -> TokenUsageRatio {
 235        #[cfg(debug_assertions)]
 236        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 237            .unwrap_or("0.8".to_string())
 238            .parse()
 239            .unwrap();
 240        #[cfg(not(debug_assertions))]
 241        let warning_threshold: f32 = 0.8;
 242
 243        if self.total >= self.max {
 244            TokenUsageRatio::Exceeded
 245        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 246            TokenUsageRatio::Warning
 247        } else {
 248            TokenUsageRatio::Normal
 249        }
 250    }
 251
 252    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 253        TotalTokenUsage {
 254            total: self.total + tokens,
 255            max: self.max,
 256        }
 257    }
 258}
 259
 260#[derive(Debug, Default, PartialEq, Eq)]
 261pub enum TokenUsageRatio {
 262    #[default]
 263    Normal,
 264    Warning,
 265    Exceeded,
 266}
 267
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Refreshed by `touch_updated_at` (e.g. when messages are inserted, edited or deleted).
    updated_at: DateTime<Utc>,
    /// `None` until a summary exists; `set_summary` is a no-op while this is `None`.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    /// Id that the next inserted message will receive.
    next_message_id: MessageId,
    /// Every context that has been attached to this thread, keyed by id.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Ids of the contexts that were newly attached with each user message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    /// Checkpoints that were kept, keyed by the message they belong to.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Status of the most recent checkpoint restore, if one is pending or failed.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint from the latest user message, awaiting `finalize_pending_checkpoint`.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
}
 298
/// Details recorded when a message pushed a thread past the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 306
 307impl Thread {
    /// Creates an empty thread with a freshly generated [`ThreadId`].
    ///
    /// `system_prompt` is stored as the thread's project context. A new
    /// [`ActionLog`] entity is created for `project`, and a project snapshot is
    /// started right away; its result is kept as a shared task in
    /// `initial_project_snapshot`.
    pub fn new(
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            initial_project_snapshot: {
                // `project` is moved here, so every earlier use above clones it.
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            request_token_usage: Vec::new(),
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
        }
    }
 350
    /// Rebuilds a thread from its serialized form.
    ///
    /// Messages, summary, timestamps, token usage, the detailed summary state and
    /// the initial project snapshot come from `serialized`; tool-use state is
    /// reconstructed from the serialized messages. Attached contexts, checkpoints,
    /// feedback and pending work are not restored and start out empty.
    ///
    /// NOTE(review): `serialize` writes `exceeded_window_error`, but it is reset to
    /// `None` here — confirm that dropping it on load is intended.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Entity<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering one past the last serialized message (0 when there are none).
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text } => {
                                MessageSegment::Thinking(text)
                            }
                        })
                        .collect(),
                    context: message.context,
                })
                .collect(),
            next_message_id,
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            request_token_usage: serialized.request_token_usage,
            cumulative_token_usage: serialized.cumulative_token_usage,
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
        }
    }
 418
    /// Returns this thread's identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 422
    /// Returns `true` when the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 426
    /// Returns the time the thread was last marked as updated.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 430
    /// Sets the thread's `updated_at` timestamp to the current time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 434
    /// Returns a clone of the thread's summary, or `None` when it has none yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
 438
    /// Returns a clone of the shared project context this thread was created with.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 442
    /// Summary used when a thread has no summary of its own, and substituted by
    /// `set_summary` for an empty one.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 444
    /// Returns the thread's summary, falling back to [`Self::DEFAULT_SUMMARY`].
    pub fn summary_or_default(&self) -> SharedString {
        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
    }
 448
 449    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 450        let Some(current_summary) = &self.summary else {
 451            // Don't allow setting summary until generated
 452            return;
 453        };
 454
 455        let mut new_summary = new_summary.into();
 456
 457        if new_summary.is_empty() {
 458            new_summary = Self::DEFAULT_SUMMARY;
 459        }
 460
 461        if current_summary != &new_summary {
 462            self.summary = Some(new_summary);
 463            cx.emit(ThreadEvent::SummaryChanged);
 464        }
 465    }
 466
 467    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 468        self.latest_detailed_summary()
 469            .unwrap_or_else(|| self.text().into())
 470    }
 471
 472    fn latest_detailed_summary(&self) -> Option<SharedString> {
 473        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 474            Some(text.clone())
 475        } else {
 476            None
 477        }
 478    }
 479
    /// Looks up a message by id with a linear scan; `None` when no message has that id.
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }
 483
    /// Iterates over the thread's messages in stored order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }
 487
 488    pub fn is_generating(&self) -> bool {
 489        !self.pending_completions.is_empty() || !self.all_tools_finished()
 490    }
 491
    /// Returns the tool working set this thread was created with.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
 495
    /// Returns the pending tool use with the given id, if there is one.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }
 502
    /// Iterates over the pending tool uses whose status reports that it needs confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
 509
    /// Returns `true` when at least one tool use is pending, whatever its status.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
 513
    /// Returns a clone of the checkpoint stored for message `id`, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 517
    /// Asks the project's git store to restore `checkpoint` and updates the thread
    /// once that finishes.
    ///
    /// The restore is marked as pending immediately. When it completes:
    /// - on failure, the error is recorded in `last_restore_checkpoint`;
    /// - on success, the thread is truncated at the checkpoint's message and
    ///   `last_restore_checkpoint` is cleared.
    ///
    /// In both cases `pending_checkpoint` is cleared and
    /// [`ThreadEvent::CheckpointChanged`] is emitted. The returned task resolves to
    /// the git store's result, or to an error if updating the thread itself failed.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // The restore succeeded: drop the checkpoint's message and
                    // everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 553
    /// Decides whether the pending checkpoint is worth keeping once generation has
    /// stopped.
    ///
    /// Returns early, leaving the pending checkpoint in place, while the thread is
    /// still generating; also returns when there is no pending checkpoint. Otherwise
    /// a new checkpoint is taken and compared with the pending one:
    /// - equal: the pending checkpoint is deleted from the git store;
    /// - different, or the comparison failed: the pending checkpoint is stored for
    ///   its message via `insert_checkpoint`.
    ///
    /// The freshly taken checkpoint is deleted afterwards in either case. If taking
    /// it fails, the pending checkpoint is stored as well. The work runs in a
    /// detached task.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison is treated the same as "not equal".
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The final checkpoint was only needed for the comparison.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 604
    /// Stores `checkpoint` under its message id (replacing any previous entry) and
    /// emits [`ThreadEvent::CheckpointChanged`].
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
 611
    /// Returns the status of the most recent checkpoint restore, if one is pending
    /// or has failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 615
 616    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 617        let Some(message_ix) = self
 618            .messages
 619            .iter()
 620            .rposition(|message| message.id == message_id)
 621        else {
 622            return;
 623        };
 624        for deleted_message in self.messages.drain(message_ix..) {
 625            self.context_by_message.remove(&deleted_message.id);
 626            self.checkpoints_by_message.remove(&deleted_message.id);
 627        }
 628        cx.notify();
 629    }
 630
    /// Iterates over the contexts recorded for message `id`.
    ///
    /// Yields nothing when the message has no recorded context ids; ids that are
    /// not present in `self.context` are skipped.
    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
        self.context_by_message
            .get(&id)
            .into_iter()
            .flat_map(|context| {
                context
                    .iter()
                    .filter_map(|context_id| self.context.get(&context_id))
            })
    }
 641
 642    /// Returns whether all of the tool uses have finished running.
 643    pub fn all_tools_finished(&self) -> bool {
 644        // If the only pending tool uses left are the ones with errors, then
 645        // that means that we've finished running all of the pending tools.
 646        self.tool_use
 647            .pending_tool_uses()
 648            .iter()
 649            .all(|tool_use| tool_use.status.is_error())
 650    }
 651
    /// Returns the tool uses associated with message `id` (delegates to the thread's
    /// tool-use state).
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 655
    /// Returns the tool results associated with message `id` (delegates to the
    /// thread's tool-use state).
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }
 659
    /// Returns the result recorded for the tool use `id`, if any.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 663
 664    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 665        Some(&self.tool_use.tool_result(id)?.content)
 666    }
 667
    /// Returns a clone of the card attached to the result of tool use `id`, if any.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 671
    /// Returns whether message `message_id` has tool results (delegates to the
    /// thread's tool-use state).
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
 675
    /// Filter out contexts that have already been included in previous messages
    ///
    /// A context counts as already included when its id is a key of `self.context`.
    pub fn filter_new_context<'a>(
        &self,
        context: impl Iterator<Item = &'a AssistantContext>,
    ) -> impl Iterator<Item = &'a AssistantContext> {
        context.filter(|ctx| self.is_context_new(ctx))
    }
 683
    /// Returns `true` when no context with the same id has been attached to this
    /// thread yet.
    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }
 687
    /// Appends a user message with `text` and attaches any context not seen before.
    ///
    /// Only contexts whose ids are new to this thread are used. When there are any:
    /// - their formatted string (if formatting yields one) becomes the message's `context`;
    /// - buffers behind file, directory, symbol and excerpt contexts are reported to
    ///   the action log as added context.
    ///
    /// The new contexts are then recorded on the thread and their ids are stored
    /// for this message (an empty list when nothing was new). If `git_checkpoint`
    /// is provided it becomes the thread's pending checkpoint for this message.
    ///
    /// Returns the id of the inserted message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::Excerpt(excerpt_context) => {
                            log.buffer_added_as_context(
                                excerpt_context.context_buffer.buffer.clone(),
                                cx,
                            );
                        }
                        // These variants report no buffer to the action log.
                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
                    }
                }
            });
        }

        // Collect the ids before `new_context` is consumed below.
        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
 763
    /// Appends a message with the given role and segments and returns its id.
    ///
    /// The id is taken from the thread's message counter, the message starts with an
    /// empty `context`, `updated_at` is refreshed, and
    /// [`ThreadEvent::MessageAdded`] is emitted.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            context: String::new(),
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
 781
    /// Replaces the role and segments of the message with the given `id`.
    ///
    /// Returns `false` (and changes nothing) when no such message exists. Otherwise
    /// `updated_at` is refreshed, [`ThreadEvent::MessageEdited`] is emitted and
    /// `true` is returned. The message's `context` is left untouched.
    pub fn edit_message(
        &mut self,
        id: MessageId,
        new_role: Role,
        new_segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> bool {
        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
            return false;
        };
        message.role = new_role;
        message.segments = new_segments;
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageEdited(id));
        true
    }
 798
 799    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 800        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 801            return false;
 802        };
 803        self.messages.remove(index);
 804        self.context_by_message.remove(&id);
 805        self.touch_updated_at();
 806        cx.emit(ThreadEvent::MessageDeleted(id));
 807        true
 808    }
 809
 810    /// Returns the representation of this [`Thread`] in a textual form.
 811    ///
 812    /// This is the representation we use when attaching a thread as context to another thread.
 813    pub fn text(&self) -> String {
 814        let mut text = String::new();
 815
 816        for message in &self.messages {
 817            text.push_str(match message.role {
 818                language_model::Role::User => "User:",
 819                language_model::Role::Assistant => "Assistant:",
 820                language_model::Role::System => "System:",
 821            });
 822            text.push('\n');
 823
 824            for segment in &message.segments {
 825                match segment {
 826                    MessageSegment::Text(content) => text.push_str(content),
 827                    MessageSegment::Thinking(content) => {
 828                        text.push_str(&format!("<think>{}</think>", content))
 829                    }
 830                }
 831            }
 832            text.push('\n');
 833        }
 834
 835        text
 836    }
 837
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// The returned task first waits for the initial project snapshot, then reads
    /// the thread's state at that point: the summary (or the default one), the
    /// timestamp, every message with its segments, tool uses, tool results and
    /// context, plus token usage, the detailed summary state and any
    /// exceeded-window error. The task fails if the thread can no longer be read.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking(text) => {
                                    SerializedMessageSegment::Thinking { text: text.clone() }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
 893
 894    pub fn send_to_model(
 895        &mut self,
 896        model: Arc<dyn LanguageModel>,
 897        request_kind: RequestKind,
 898        cx: &mut Context<Self>,
 899    ) {
 900        let mut request = self.to_completion_request(request_kind, cx);
 901        if model.supports_tools() {
 902            request.tools = {
 903                let mut tools = Vec::new();
 904                tools.extend(
 905                    self.tools()
 906                        .read(cx)
 907                        .enabled_tools(cx)
 908                        .into_iter()
 909                        .filter_map(|tool| {
 910                            // Skip tools that cannot be supported
 911                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 912                            Some(LanguageModelRequestTool {
 913                                name: tool.name(),
 914                                description: tool.description(),
 915                                input_schema,
 916                            })
 917                        }),
 918                );
 919
 920                tools
 921            };
 922        }
 923
 924        self.stream_completion(request, model, cx);
 925    }
 926
 927    pub fn used_tools_since_last_user_message(&self) -> bool {
 928        for message in self.messages.iter().rev() {
 929            if self.tool_use.message_has_tool_results(message.id) {
 930                return true;
 931            } else if message.role == Role::User {
 932                return false;
 933            }
 934        }
 935
 936        false
 937    }
 938
 939    pub fn to_completion_request(
 940        &self,
 941        request_kind: RequestKind,
 942        cx: &App,
 943    ) -> LanguageModelRequest {
 944        let mut request = LanguageModelRequest {
 945            messages: vec![],
 946            tools: Vec::new(),
 947            stop: Vec::new(),
 948            temperature: None,
 949        };
 950
 951        if let Some(project_context) = self.project_context.borrow().as_ref() {
 952            if let Some(system_prompt) = self
 953                .prompt_builder
 954                .generate_assistant_system_prompt(project_context)
 955                .context("failed to generate assistant system prompt")
 956                .log_err()
 957            {
 958                request.messages.push(LanguageModelRequestMessage {
 959                    role: Role::System,
 960                    content: vec![MessageContent::Text(system_prompt)],
 961                    cache: true,
 962                });
 963            }
 964        } else {
 965            log::error!("project_context not set.")
 966        }
 967
 968        for message in &self.messages {
 969            let mut request_message = LanguageModelRequestMessage {
 970                role: message.role,
 971                content: Vec::new(),
 972                cache: false,
 973            };
 974
 975            match request_kind {
 976                RequestKind::Chat => {
 977                    self.tool_use
 978                        .attach_tool_results(message.id, &mut request_message);
 979                }
 980                RequestKind::Summarize => {
 981                    // We don't care about tool use during summarization.
 982                    if self.tool_use.message_has_tool_results(message.id) {
 983                        continue;
 984                    }
 985                }
 986            }
 987
 988            if !message.segments.is_empty() {
 989                request_message
 990                    .content
 991                    .push(MessageContent::Text(message.to_string()));
 992            }
 993
 994            match request_kind {
 995                RequestKind::Chat => {
 996                    self.tool_use
 997                        .attach_tool_uses(message.id, &mut request_message);
 998                }
 999                RequestKind::Summarize => {
1000                    // We don't care about tool use during summarization.
1001                }
1002            };
1003
1004            request.messages.push(request_message);
1005        }
1006
1007        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1008        if let Some(last) = request.messages.last_mut() {
1009            last.cache = true;
1010        }
1011
1012        self.attached_tracked_files_state(&mut request.messages, cx);
1013
1014        request
1015    }
1016
1017    fn attached_tracked_files_state(
1018        &self,
1019        messages: &mut Vec<LanguageModelRequestMessage>,
1020        cx: &App,
1021    ) {
1022        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1023
1024        let mut stale_message = String::new();
1025
1026        let action_log = self.action_log.read(cx);
1027
1028        for stale_file in action_log.stale_buffers(cx) {
1029            let Some(file) = stale_file.read(cx).file() else {
1030                continue;
1031            };
1032
1033            if stale_message.is_empty() {
1034                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1035            }
1036
1037            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1038        }
1039
1040        let mut content = Vec::with_capacity(2);
1041
1042        if !stale_message.is_empty() {
1043            content.push(stale_message.into());
1044        }
1045
1046        if action_log.has_edited_files_since_project_diagnostics_check() {
1047            content.push(
1048                "\n\nWhen you're done making changes, make sure to check project diagnostics \
1049                and fix all errors AND warnings you introduced! \
1050                DO NOT mention you're going to do this until you're done."
1051                    .into(),
1052            );
1053        }
1054
1055        if !content.is_empty() {
1056            let context_message = LanguageModelRequestMessage {
1057                role: Role::User,
1058                content,
1059                cache: false,
1060            };
1061
1062            messages.push(context_message);
1063        }
1064    }
1065
    /// Starts streaming a completion for `request` from `model` and applies the streamed events
    /// to this thread.
    ///
    /// The spawned task is stored in `pending_completions` under a fresh id taken from
    /// `completion_count`. While the stream is running, each event updates the thread (new
    /// assistant messages, text/thinking chunks, tool-use requests, token usage, stop reason) and
    /// emits `ThreadEvent::StreamedCompletion`. When the stream ends without error, the pending
    /// completion is removed and a summary is requested if the thread has none yet and holds at
    /// least two messages.
    ///
    /// Afterwards, regardless of the outcome, the pending checkpoint is finalized, the stop
    /// reason or error is handled (pending tools are started on `StopReason::ToolUse`; errors are
    /// surfaced as `ThreadEvent::ShowError` or recorded as `exceeded_window_error`, and the last
    /// completion is canceled), `ThreadEvent::Stopped` is emitted, and a telemetry event with the
    /// token usage consumed since the start of this call is reported.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);
        let task = cx.spawn(async move |thread, cx| {
            let stream_completion_future = model.stream_completion_with_usage(request, &cx);
            // Remember the usage before this completion so the telemetry below can report a delta.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let (mut events, usage) = stream_completion_future.await?;
                // Used when the stream ends without an explicit `Stop` event.
                let mut stop_reason = StopReason::EndTurn;
                // Latest usage reported for this completion; see `UsageUpdate` below.
                let mut current_token_usage = TokenUsage::default();

                if let Some(usage) = usage {
                    let limit = match usage.limit {
                        UsageLimit::Limited(limit) => limit.to_string(),
                        UsageLimit::Unlimited => "unlimited".to_string(),
                    };
                    log::info!("model request usage: {} / {}", usage.amount, limit);
                }

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                thread.update_token_usage_at_last_message(token_usage);
                                // Add only the difference to the previous update of this
                                // completion, so repeated updates are not counted twice.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the most recent assistant message,
                                // inserting an empty one if the thread has none yet.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Give other tasks a chance to run between streamed events.
                    smol::future::yield_now().await;
                }

                // The stream finished without an error: this completion is no longer pending.
                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            // Map well-known error types to dedicated UI errors; anything else is
                            // shown as a generic message built from the whole error chain.
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(error) =
                                error.downcast_ref::<ModelRequestLimitReachedError>()
                            {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        // Recorded on the thread instead of being emitted as a
                                        // `ShowError` event.
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    thread.auto_capture_telemetry(cx);

                    // Only report usage when the initial usage could be read from the thread.
                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1280
1281    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1282        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1283            return;
1284        };
1285
1286        if !model.provider.is_authenticated(cx) {
1287            return;
1288        }
1289
1290        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1291        request.messages.push(LanguageModelRequestMessage {
1292            role: Role::User,
1293            content: vec![
1294                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1295                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1296                 If the conversation is about a specific subject, include it in the title. \
1297                 Be descriptive. DO NOT speak in the first person."
1298                    .into(),
1299            ],
1300            cache: false,
1301        });
1302
1303        self.pending_summary = cx.spawn(async move |this, cx| {
1304            async move {
1305                let stream = model.model.stream_completion_text(request, &cx);
1306                let mut messages = stream.await?;
1307
1308                let mut new_summary = String::new();
1309                while let Some(message) = messages.stream.next().await {
1310                    let text = message?;
1311                    let mut lines = text.lines();
1312                    new_summary.extend(lines.next());
1313
1314                    // Stop if the LLM generated multiple lines.
1315                    if lines.next().is_some() {
1316                        break;
1317                    }
1318                }
1319
1320                this.update(cx, |this, cx| {
1321                    if !new_summary.is_empty() {
1322                        this.summary = Some(new_summary.into());
1323                    }
1324
1325                    cx.emit(ThreadEvent::SummaryGenerated);
1326                })?;
1327
1328                anyhow::Ok(())
1329            }
1330            .log_err()
1331            .await
1332        });
1333    }
1334
1335    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1336        let last_message_id = self.messages.last().map(|message| message.id)?;
1337
1338        match &self.detailed_summary_state {
1339            DetailedSummaryState::Generating { message_id, .. }
1340            | DetailedSummaryState::Generated { message_id, .. }
1341                if *message_id == last_message_id =>
1342            {
1343                // Already up-to-date
1344                return None;
1345            }
1346            _ => {}
1347        }
1348
1349        let ConfiguredModel { model, provider } =
1350            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1351
1352        if !provider.is_authenticated(cx) {
1353            return None;
1354        }
1355
1356        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1357
1358        request.messages.push(LanguageModelRequestMessage {
1359            role: Role::User,
1360            content: vec![
1361                "Generate a detailed summary of this conversation. Include:\n\
1362                1. A brief overview of what was discussed\n\
1363                2. Key facts or information discovered\n\
1364                3. Outcomes or conclusions reached\n\
1365                4. Any action items or next steps if any\n\
1366                Format it in Markdown with headings and bullet points."
1367                    .into(),
1368            ],
1369            cache: false,
1370        });
1371
1372        let task = cx.spawn(async move |thread, cx| {
1373            let stream = model.stream_completion_text(request, &cx);
1374            let Some(mut messages) = stream.await.log_err() else {
1375                thread
1376                    .update(cx, |this, _cx| {
1377                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1378                    })
1379                    .log_err();
1380
1381                return;
1382            };
1383
1384            let mut new_detailed_summary = String::new();
1385
1386            while let Some(chunk) = messages.stream.next().await {
1387                if let Some(chunk) = chunk.log_err() {
1388                    new_detailed_summary.push_str(&chunk);
1389                }
1390            }
1391
1392            thread
1393                .update(cx, |this, _cx| {
1394                    this.detailed_summary_state = DetailedSummaryState::Generated {
1395                        text: new_detailed_summary.into(),
1396                        message_id: last_message_id,
1397                    };
1398                })
1399                .log_err();
1400        });
1401
1402        self.detailed_summary_state = DetailedSummaryState::Generating {
1403            message_id: last_message_id,
1404        };
1405
1406        Some(task)
1407    }
1408
    /// Returns `true` while `detailed_summary_state` is `DetailedSummaryState::Generating`,
    /// whatever message that generation refers to.
    pub fn is_generating_detailed_summary(&self) -> bool {
        matches!(
            self.detailed_summary_state,
            DetailedSummaryState::Generating { .. }
        )
    }
1415
1416    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1417        self.auto_capture_telemetry(cx);
1418        let request = self.to_completion_request(RequestKind::Chat, cx);
1419        let messages = Arc::new(request.messages);
1420        let pending_tool_uses = self
1421            .tool_use
1422            .pending_tool_uses()
1423            .into_iter()
1424            .filter(|tool_use| tool_use.status.is_idle())
1425            .cloned()
1426            .collect::<Vec<_>>();
1427
1428        for tool_use in pending_tool_uses.iter() {
1429            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1430                if tool.needs_confirmation(&tool_use.input, cx)
1431                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1432                {
1433                    self.tool_use.confirm_tool_use(
1434                        tool_use.id.clone(),
1435                        tool_use.ui_text.clone(),
1436                        tool_use.input.clone(),
1437                        messages.clone(),
1438                        tool,
1439                    );
1440                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1441                } else {
1442                    self.run_tool(
1443                        tool_use.id.clone(),
1444                        tool_use.ui_text.clone(),
1445                        tool_use.input.clone(),
1446                        &messages,
1447                        tool,
1448                        cx,
1449                    );
1450                }
1451            }
1452        }
1453
1454        pending_tool_uses
1455    }
1456
1457    pub fn run_tool(
1458        &mut self,
1459        tool_use_id: LanguageModelToolUseId,
1460        ui_text: impl Into<SharedString>,
1461        input: serde_json::Value,
1462        messages: &[LanguageModelRequestMessage],
1463        tool: Arc<dyn Tool>,
1464        cx: &mut Context<Thread>,
1465    ) {
1466        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1467        self.tool_use
1468            .run_pending_tool(tool_use_id, ui_text.into(), task);
1469    }
1470
1471    fn spawn_tool_use(
1472        &mut self,
1473        tool_use_id: LanguageModelToolUseId,
1474        messages: &[LanguageModelRequestMessage],
1475        input: serde_json::Value,
1476        tool: Arc<dyn Tool>,
1477        cx: &mut Context<Thread>,
1478    ) -> Task<()> {
1479        let tool_name: Arc<str> = tool.name().into();
1480
1481        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1482            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1483        } else {
1484            tool.run(
1485                input,
1486                messages,
1487                self.project.clone(),
1488                self.action_log.clone(),
1489                cx,
1490            )
1491        };
1492
1493        // Store the card separately if it exists
1494        if let Some(card) = tool_result.card.clone() {
1495            self.tool_use
1496                .insert_tool_result_card(tool_use_id.clone(), card);
1497        }
1498
1499        cx.spawn({
1500            async move |thread: WeakEntity<Thread>, cx| {
1501                let output = tool_result.output.await;
1502
1503                thread
1504                    .update(cx, |thread, cx| {
1505                        let pending_tool_use = thread.tool_use.insert_tool_output(
1506                            tool_use_id.clone(),
1507                            tool_name,
1508                            output,
1509                            cx,
1510                        );
1511                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1512                    })
1513                    .ok();
1514            }
1515        })
1516    }
1517
1518    fn tool_finished(
1519        &mut self,
1520        tool_use_id: LanguageModelToolUseId,
1521        pending_tool_use: Option<PendingToolUse>,
1522        canceled: bool,
1523        cx: &mut Context<Self>,
1524    ) {
1525        if self.all_tools_finished() {
1526            let model_registry = LanguageModelRegistry::read_global(cx);
1527            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1528                self.attach_tool_results(cx);
1529                if !canceled {
1530                    self.send_to_model(model, RequestKind::Chat, cx);
1531                }
1532            }
1533        }
1534
1535        cx.emit(ThreadEvent::ToolFinished {
1536            tool_use_id,
1537            pending_tool_use,
1538        });
1539    }
1540
1541    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1542        // Insert a user message to contain the tool results.
1543        self.insert_user_message(
1544            // TODO: Sending up a user message without any content results in the model sending back
1545            // responses that also don't have any content. We currently don't handle this case well,
1546            // so for now we provide some text to keep the model on track.
1547            "Here are the tool results.",
1548            Vec::new(),
1549            None,
1550            cx,
1551        );
1552    }
1553
1554    /// Cancels the last pending completion, if there are any pending.
1555    ///
1556    /// Returns whether a completion was canceled.
1557    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1558        let canceled = if self.pending_completions.pop().is_some() {
1559            true
1560        } else {
1561            let mut canceled = false;
1562            for pending_tool_use in self.tool_use.cancel_pending() {
1563                canceled = true;
1564                self.tool_finished(
1565                    pending_tool_use.id.clone(),
1566                    Some(pending_tool_use),
1567                    true,
1568                    cx,
1569                );
1570            }
1571            canceled
1572        };
1573        self.finalize_pending_checkpoint(cx);
1574        canceled
1575    }
1576
    /// Returns the current value of the thread's `feedback` field, or `None` if it is unset.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1580
    /// Returns the feedback stored for the message `message_id`, or `None` if there is none.
    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
        self.message_feedback.get(&message_id).copied()
    }
1584
1585    pub fn report_message_feedback(
1586        &mut self,
1587        message_id: MessageId,
1588        feedback: ThreadFeedback,
1589        cx: &mut Context<Self>,
1590    ) -> Task<Result<()>> {
1591        if self.message_feedback.get(&message_id) == Some(&feedback) {
1592            return Task::ready(Ok(()));
1593        }
1594
1595        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1596        let serialized_thread = self.serialize(cx);
1597        let thread_id = self.id().clone();
1598        let client = self.project.read(cx).client();
1599
1600        let enabled_tool_names: Vec<String> = self
1601            .tools()
1602            .read(cx)
1603            .enabled_tools(cx)
1604            .iter()
1605            .map(|tool| tool.name().to_string())
1606            .collect();
1607
1608        self.message_feedback.insert(message_id, feedback);
1609
1610        cx.notify();
1611
1612        let message_content = self
1613            .message(message_id)
1614            .map(|msg| msg.to_string())
1615            .unwrap_or_default();
1616
1617        cx.background_spawn(async move {
1618            let final_project_snapshot = final_project_snapshot.await;
1619            let serialized_thread = serialized_thread.await?;
1620            let thread_data =
1621                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1622
1623            let rating = match feedback {
1624                ThreadFeedback::Positive => "positive",
1625                ThreadFeedback::Negative => "negative",
1626            };
1627            telemetry::event!(
1628                "Assistant Thread Rated",
1629                rating,
1630                thread_id,
1631                enabled_tool_names,
1632                message_id = message_id.0,
1633                message_content,
1634                thread_data,
1635                final_project_snapshot
1636            );
1637            client.telemetry().flush_events();
1638
1639            Ok(())
1640        })
1641    }
1642
1643    pub fn report_feedback(
1644        &mut self,
1645        feedback: ThreadFeedback,
1646        cx: &mut Context<Self>,
1647    ) -> Task<Result<()>> {
1648        let last_assistant_message_id = self
1649            .messages
1650            .iter()
1651            .rev()
1652            .find(|msg| msg.role == Role::Assistant)
1653            .map(|msg| msg.id);
1654
1655        if let Some(message_id) = last_assistant_message_id {
1656            self.report_message_feedback(message_id, feedback, cx)
1657        } else {
1658            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1659            let serialized_thread = self.serialize(cx);
1660            let thread_id = self.id().clone();
1661            let client = self.project.read(cx).client();
1662            self.feedback = Some(feedback);
1663            cx.notify();
1664
1665            cx.background_spawn(async move {
1666                let final_project_snapshot = final_project_snapshot.await;
1667                let serialized_thread = serialized_thread.await?;
1668                let thread_data = serde_json::to_value(serialized_thread)
1669                    .unwrap_or_else(|_| serde_json::Value::Null);
1670
1671                let rating = match feedback {
1672                    ThreadFeedback::Positive => "positive",
1673                    ThreadFeedback::Negative => "negative",
1674                };
1675                telemetry::event!(
1676                    "Assistant Thread Rated",
1677                    rating,
1678                    thread_id,
1679                    thread_data,
1680                    final_project_snapshot
1681                );
1682                client.telemetry().flush_events();
1683
1684                Ok(())
1685            })
1686        }
1687    }
1688
1689    /// Create a snapshot of the current project state including git information and unsaved buffers.
1690    fn project_snapshot(
1691        project: Entity<Project>,
1692        cx: &mut Context<Self>,
1693    ) -> Task<Arc<ProjectSnapshot>> {
1694        let git_store = project.read(cx).git_store().clone();
1695        let worktree_snapshots: Vec<_> = project
1696            .read(cx)
1697            .visible_worktrees(cx)
1698            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1699            .collect();
1700
1701        cx.spawn(async move |_, cx| {
1702            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1703
1704            let mut unsaved_buffers = Vec::new();
1705            cx.update(|app_cx| {
1706                let buffer_store = project.read(app_cx).buffer_store();
1707                for buffer_handle in buffer_store.read(app_cx).buffers() {
1708                    let buffer = buffer_handle.read(app_cx);
1709                    if buffer.is_dirty() {
1710                        if let Some(file) = buffer.file() {
1711                            let path = file.path().to_string_lossy().to_string();
1712                            unsaved_buffers.push(path);
1713                        }
1714                    }
1715                }
1716            })
1717            .ok();
1718
1719            Arc::new(ProjectSnapshot {
1720                worktree_snapshots,
1721                unsaved_buffer_paths: unsaved_buffers,
1722                timestamp: Utc::now(),
1723            })
1724        })
1725    }
1726
1727    fn worktree_snapshot(
1728        worktree: Entity<project::Worktree>,
1729        git_store: Entity<GitStore>,
1730        cx: &App,
1731    ) -> Task<WorktreeSnapshot> {
1732        cx.spawn(async move |cx| {
1733            // Get worktree path and snapshot
1734            let worktree_info = cx.update(|app_cx| {
1735                let worktree = worktree.read(app_cx);
1736                let path = worktree.abs_path().to_string_lossy().to_string();
1737                let snapshot = worktree.snapshot();
1738                (path, snapshot)
1739            });
1740
1741            let Ok((worktree_path, _snapshot)) = worktree_info else {
1742                return WorktreeSnapshot {
1743                    worktree_path: String::new(),
1744                    git_state: None,
1745                };
1746            };
1747
1748            let git_state = git_store
1749                .update(cx, |git_store, cx| {
1750                    git_store
1751                        .repositories()
1752                        .values()
1753                        .find(|repo| {
1754                            repo.read(cx)
1755                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1756                                .is_some()
1757                        })
1758                        .cloned()
1759                })
1760                .ok()
1761                .flatten()
1762                .map(|repo| {
1763                    repo.update(cx, |repo, _| {
1764                        let current_branch =
1765                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1766                        repo.send_job(None, |state, _| async move {
1767                            let RepositoryState::Local { backend, .. } = state else {
1768                                return GitState {
1769                                    remote_url: None,
1770                                    head_sha: None,
1771                                    current_branch,
1772                                    diff: None,
1773                                };
1774                            };
1775
1776                            let remote_url = backend.remote_url("origin");
1777                            let head_sha = backend.head_sha();
1778                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1779
1780                            GitState {
1781                                remote_url,
1782                                head_sha,
1783                                current_branch,
1784                                diff,
1785                            }
1786                        })
1787                    })
1788                });
1789
1790            let git_state = match git_state {
1791                Some(git_state) => match git_state.ok() {
1792                    Some(git_state) => git_state.await.ok(),
1793                    None => None,
1794                },
1795                None => None,
1796            };
1797
1798            WorktreeSnapshot {
1799                worktree_path,
1800                git_state,
1801            }
1802        })
1803    }
1804
1805    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1806        let mut markdown = Vec::new();
1807
1808        if let Some(summary) = self.summary() {
1809            writeln!(markdown, "# {summary}\n")?;
1810        };
1811
1812        for message in self.messages() {
1813            writeln!(
1814                markdown,
1815                "## {role}\n",
1816                role = match message.role {
1817                    Role::User => "User",
1818                    Role::Assistant => "Assistant",
1819                    Role::System => "System",
1820                }
1821            )?;
1822
1823            if !message.context.is_empty() {
1824                writeln!(markdown, "{}", message.context)?;
1825            }
1826
1827            for segment in &message.segments {
1828                match segment {
1829                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1830                    MessageSegment::Thinking(text) => {
1831                        writeln!(markdown, "<think>{}</think>\n", text)?
1832                    }
1833                }
1834            }
1835
1836            for tool_use in self.tool_uses_for_message(message.id, cx) {
1837                writeln!(
1838                    markdown,
1839                    "**Use Tool: {} ({})**",
1840                    tool_use.name, tool_use.id
1841                )?;
1842                writeln!(markdown, "```json")?;
1843                writeln!(
1844                    markdown,
1845                    "{}",
1846                    serde_json::to_string_pretty(&tool_use.input)?
1847                )?;
1848                writeln!(markdown, "```")?;
1849            }
1850
1851            for tool_result in self.tool_results_for_message(message.id) {
1852                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1853                if tool_result.is_error {
1854                    write!(markdown, " (Error)")?;
1855                }
1856
1857                writeln!(markdown, "**\n")?;
1858                writeln!(markdown, "{}", tool_result.content)?;
1859            }
1860        }
1861
1862        Ok(String::from_utf8_lossy(&markdown).to_string())
1863    }
1864
1865    pub fn keep_edits_in_range(
1866        &mut self,
1867        buffer: Entity<language::Buffer>,
1868        buffer_range: Range<language::Anchor>,
1869        cx: &mut Context<Self>,
1870    ) {
1871        self.action_log.update(cx, |action_log, cx| {
1872            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1873        });
1874    }
1875
1876    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1877        self.action_log
1878            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1879    }
1880
1881    pub fn reject_edits_in_ranges(
1882        &mut self,
1883        buffer: Entity<language::Buffer>,
1884        buffer_ranges: Vec<Range<language::Anchor>>,
1885        cx: &mut Context<Self>,
1886    ) -> Task<Result<()>> {
1887        self.action_log.update(cx, |action_log, cx| {
1888            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1889        })
1890    }
1891
    /// Returns the action log entity held by this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1895
    /// Returns the project entity held by this thread.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1899
1900    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1901        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1902            return;
1903        }
1904
1905        let now = Instant::now();
1906        if let Some(last) = self.last_auto_capture_at {
1907            if now.duration_since(last).as_secs() < 10 {
1908                return;
1909            }
1910        }
1911
1912        self.last_auto_capture_at = Some(now);
1913
1914        let thread_id = self.id().clone();
1915        let github_login = self
1916            .project
1917            .read(cx)
1918            .user_store()
1919            .read(cx)
1920            .current_user()
1921            .map(|user| user.github_login.clone());
1922        let client = self.project.read(cx).client().clone();
1923        let serialize_task = self.serialize(cx);
1924
1925        cx.background_executor()
1926            .spawn(async move {
1927                if let Ok(serialized_thread) = serialize_task.await {
1928                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1929                        telemetry::event!(
1930                            "Agent Thread Auto-Captured",
1931                            thread_id = thread_id.to_string(),
1932                            thread_data = thread_data,
1933                            auto_capture_reason = "tracked_user",
1934                            github_login = github_login
1935                        );
1936
1937                        client.telemetry().flush_events();
1938                    }
1939                }
1940            })
1941            .detach();
1942    }
1943
    /// Returns a copy of the thread's `cumulative_token_usage` value.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
1947
1948    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
1949        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
1950            return TotalTokenUsage::default();
1951        };
1952
1953        let max = model.model.max_token_count();
1954
1955        let index = self
1956            .messages
1957            .iter()
1958            .position(|msg| msg.id == message_id)
1959            .unwrap_or(0);
1960
1961        if index == 0 {
1962            return TotalTokenUsage { total: 0, max };
1963        }
1964
1965        let token_usage = &self
1966            .request_token_usage
1967            .get(index - 1)
1968            .cloned()
1969            .unwrap_or_default();
1970
1971        TotalTokenUsage {
1972            total: token_usage.total_tokens() as usize,
1973            max,
1974        }
1975    }
1976
1977    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1978        let model_registry = LanguageModelRegistry::read_global(cx);
1979        let Some(model) = model_registry.default_model() else {
1980            return TotalTokenUsage::default();
1981        };
1982
1983        let max = model.model.max_token_count();
1984
1985        if let Some(exceeded_error) = &self.exceeded_window_error {
1986            if model.model.id() == exceeded_error.model_id {
1987                return TotalTokenUsage {
1988                    total: exceeded_error.token_count,
1989                    max,
1990                };
1991            }
1992        }
1993
1994        let total = self
1995            .token_usage_at_last_message()
1996            .unwrap_or_default()
1997            .total_tokens() as usize;
1998
1999        TotalTokenUsage { total, max }
2000    }
2001
2002    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
2003        self.request_token_usage
2004            .get(self.messages.len().saturating_sub(1))
2005            .or_else(|| self.request_token_usage.last())
2006            .cloned()
2007    }
2008
2009    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
2010        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
2011        self.request_token_usage
2012            .resize(self.messages.len(), placeholder);
2013
2014        if let Some(last) = self.request_token_usage.last_mut() {
2015            *last = token_usage;
2016        }
2017    }
2018
2019    pub fn deny_tool_use(
2020        &mut self,
2021        tool_use_id: LanguageModelToolUseId,
2022        tool_name: Arc<str>,
2023        cx: &mut Context<Self>,
2024    ) {
2025        let err = Err(anyhow::anyhow!(
2026            "Permission to run tool action denied by user"
2027        ));
2028
2029        self.tool_use
2030            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2031        self.tool_finished(tool_use_id.clone(), None, true, cx);
2032    }
2033}
2034
/// Errors carried by [`ThreadEvent::ShowError`].
///
/// The `Display` text of each variant is given by its `#[error(...)]` attribute.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    #[error("Payment required")]
    PaymentRequired,
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// Carries the `Plan` associated with the limit
    /// (presumably the user's current plan; confirm at the construction site).
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A free-form error made of a header and a message; both appear in the `Display` text.
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2049
/// Events that a `Thread` can emit (see the `EventEmitter<ThreadEvent>` impl for `Thread`).
///
/// NOTE(review): the variants are not emitted in this part of the file, so the
/// notes below describe payloads only; confirm exact semantics at the emit sites.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a [`ThreadError`].
    ShowError(ThreadError),
    StreamedCompletion,
    /// Carries a message id and a piece of text.
    StreamedAssistantText(MessageId, String),
    /// Carries a message id and a piece of thinking text.
    StreamedAssistantThinking(MessageId, String),
    /// Carries either a `StopReason` or a shared error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    /// Carries a list of pending tool uses.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        // NOTE(review): marked unused; presumably no listener reads this field yet.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}
2074
// Declares `Thread` as an emitter of `ThreadEvent`s.
impl EventEmitter<ThreadEvent> for Thread {}
2076
/// An entry in the thread's `pending_completions` list.
struct PendingCompletion {
    /// Numeric identifier of this pending completion.
    id: usize,
    // Never read through this field. Presumably kept so that removing the entry
    // drops (and thereby cancels) the task; confirm against gpui's `Task` drop semantics.
    _task: Task<()>,
}
2081
2082#[cfg(test)]
2083mod tests {
2084    use super::*;
2085    use crate::{ThreadStore, context_store::ContextStore, thread_store};
2086    use assistant_settings::AssistantSettings;
2087    use context_server::ContextServerSettings;
2088    use editor::EditorSettings;
2089    use gpui::TestAppContext;
2090    use project::{FakeFs, Project};
2091    use prompt_store::PromptBuilder;
2092    use serde_json::json;
2093    use settings::{Settings, SettingsStore};
2094    use std::sync::Arc;
2095    use theme::ThemeSettings;
2096    use util::path;
2097    use workspace::Workspace;
2098
2099    #[gpui::test]
2100    async fn test_message_with_context(cx: &mut TestAppContext) {
2101        init_test_settings(cx);
2102
2103        let project = create_test_project(
2104            cx,
2105            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2106        )
2107        .await;
2108
2109        let (_workspace, _thread_store, thread, context_store) =
2110            setup_test_environment(cx, project.clone()).await;
2111
2112        add_file_to_context(&project, &context_store, "test/code.rs", cx)
2113            .await
2114            .unwrap();
2115
2116        let context =
2117            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2118
2119        // Insert user message with context
2120        let message_id = thread.update(cx, |thread, cx| {
2121            thread.insert_user_message("Please explain this code", vec![context], None, cx)
2122        });
2123
2124        // Check content and context in message object
2125        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2126
2127        // Use different path format strings based on platform for the test
2128        #[cfg(windows)]
2129        let path_part = r"test\code.rs";
2130        #[cfg(not(windows))]
2131        let path_part = "test/code.rs";
2132
2133        let expected_context = format!(
2134            r#"
2135<context>
2136The following items were attached by the user. You don't need to use other tools to read them.
2137
2138<files>
2139```rs {path_part}
2140fn main() {{
2141    println!("Hello, world!");
2142}}
2143```
2144</files>
2145</context>
2146"#
2147        );
2148
2149        assert_eq!(message.role, Role::User);
2150        assert_eq!(message.segments.len(), 1);
2151        assert_eq!(
2152            message.segments[0],
2153            MessageSegment::Text("Please explain this code".to_string())
2154        );
2155        assert_eq!(message.context, expected_context);
2156
2157        // Check message in request
2158        let request = thread.read_with(cx, |thread, cx| {
2159            thread.to_completion_request(RequestKind::Chat, cx)
2160        });
2161
2162        assert_eq!(request.messages.len(), 2);
2163        let expected_full_message = format!("{}Please explain this code", expected_context);
2164        assert_eq!(request.messages[1].string_contents(), expected_full_message);
2165    }
2166
2167    #[gpui::test]
2168    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2169        init_test_settings(cx);
2170
2171        let project = create_test_project(
2172            cx,
2173            json!({
2174                "file1.rs": "fn function1() {}\n",
2175                "file2.rs": "fn function2() {}\n",
2176                "file3.rs": "fn function3() {}\n",
2177            }),
2178        )
2179        .await;
2180
2181        let (_, _thread_store, thread, context_store) =
2182            setup_test_environment(cx, project.clone()).await;
2183
2184        // Open files individually
2185        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2186            .await
2187            .unwrap();
2188        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2189            .await
2190            .unwrap();
2191        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2192            .await
2193            .unwrap();
2194
2195        // Get the context objects
2196        let contexts = context_store.update(cx, |store, _| store.context().clone());
2197        assert_eq!(contexts.len(), 3);
2198
2199        // First message with context 1
2200        let message1_id = thread.update(cx, |thread, cx| {
2201            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2202        });
2203
2204        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2205        let message2_id = thread.update(cx, |thread, cx| {
2206            thread.insert_user_message(
2207                "Message 2",
2208                vec![contexts[0].clone(), contexts[1].clone()],
2209                None,
2210                cx,
2211            )
2212        });
2213
2214        // Third message with all three contexts (contexts 1 and 2 should be skipped)
2215        let message3_id = thread.update(cx, |thread, cx| {
2216            thread.insert_user_message(
2217                "Message 3",
2218                vec![
2219                    contexts[0].clone(),
2220                    contexts[1].clone(),
2221                    contexts[2].clone(),
2222                ],
2223                None,
2224                cx,
2225            )
2226        });
2227
2228        // Check what contexts are included in each message
2229        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2230            (
2231                thread.message(message1_id).unwrap().clone(),
2232                thread.message(message2_id).unwrap().clone(),
2233                thread.message(message3_id).unwrap().clone(),
2234            )
2235        });
2236
2237        // First message should include context 1
2238        assert!(message1.context.contains("file1.rs"));
2239
2240        // Second message should include only context 2 (not 1)
2241        assert!(!message2.context.contains("file1.rs"));
2242        assert!(message2.context.contains("file2.rs"));
2243
2244        // Third message should include only context 3 (not 1 or 2)
2245        assert!(!message3.context.contains("file1.rs"));
2246        assert!(!message3.context.contains("file2.rs"));
2247        assert!(message3.context.contains("file3.rs"));
2248
2249        // Check entire request to make sure all contexts are properly included
2250        let request = thread.read_with(cx, |thread, cx| {
2251            thread.to_completion_request(RequestKind::Chat, cx)
2252        });
2253
2254        // The request should contain all 3 messages
2255        assert_eq!(request.messages.len(), 4);
2256
2257        // Check that the contexts are properly formatted in each message
2258        assert!(request.messages[1].string_contents().contains("file1.rs"));
2259        assert!(!request.messages[1].string_contents().contains("file2.rs"));
2260        assert!(!request.messages[1].string_contents().contains("file3.rs"));
2261
2262        assert!(!request.messages[2].string_contents().contains("file1.rs"));
2263        assert!(request.messages[2].string_contents().contains("file2.rs"));
2264        assert!(!request.messages[2].string_contents().contains("file3.rs"));
2265
2266        assert!(!request.messages[3].string_contents().contains("file1.rs"));
2267        assert!(!request.messages[3].string_contents().contains("file2.rs"));
2268        assert!(request.messages[3].string_contents().contains("file3.rs"));
2269    }
2270
2271    #[gpui::test]
2272    async fn test_message_without_files(cx: &mut TestAppContext) {
2273        init_test_settings(cx);
2274
2275        let project = create_test_project(
2276            cx,
2277            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2278        )
2279        .await;
2280
2281        let (_, _thread_store, thread, _context_store) =
2282            setup_test_environment(cx, project.clone()).await;
2283
2284        // Insert user message without any context (empty context vector)
2285        let message_id = thread.update(cx, |thread, cx| {
2286            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2287        });
2288
2289        // Check content and context in message object
2290        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2291
2292        // Context should be empty when no files are included
2293        assert_eq!(message.role, Role::User);
2294        assert_eq!(message.segments.len(), 1);
2295        assert_eq!(
2296            message.segments[0],
2297            MessageSegment::Text("What is the best way to learn Rust?".to_string())
2298        );
2299        assert_eq!(message.context, "");
2300
2301        // Check message in request
2302        let request = thread.read_with(cx, |thread, cx| {
2303            thread.to_completion_request(RequestKind::Chat, cx)
2304        });
2305
2306        assert_eq!(request.messages.len(), 2);
2307        assert_eq!(
2308            request.messages[1].string_contents(),
2309            "What is the best way to learn Rust?"
2310        );
2311
2312        // Add second message, also without context
2313        let message2_id = thread.update(cx, |thread, cx| {
2314            thread.insert_user_message("Are there any good books?", vec![], None, cx)
2315        });
2316
2317        let message2 =
2318            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2319        assert_eq!(message2.context, "");
2320
2321        // Check that both messages appear in the request
2322        let request = thread.read_with(cx, |thread, cx| {
2323            thread.to_completion_request(RequestKind::Chat, cx)
2324        });
2325
2326        assert_eq!(request.messages.len(), 3);
2327        assert_eq!(
2328            request.messages[1].string_contents(),
2329            "What is the best way to learn Rust?"
2330        );
2331        assert_eq!(
2332            request.messages[2].string_contents(),
2333            "Are there any good books?"
2334        );
2335    }
2336
2337    #[gpui::test]
2338    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2339        init_test_settings(cx);
2340
2341        let project = create_test_project(
2342            cx,
2343            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2344        )
2345        .await;
2346
2347        let (_workspace, _thread_store, thread, context_store) =
2348            setup_test_environment(cx, project.clone()).await;
2349
2350        // Open buffer and add it to context
2351        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2352            .await
2353            .unwrap();
2354
2355        let context =
2356            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2357
2358        // Insert user message with the buffer as context
2359        thread.update(cx, |thread, cx| {
2360            thread.insert_user_message("Explain this code", vec![context], None, cx)
2361        });
2362
2363        // Create a request and check that it doesn't have a stale buffer warning yet
2364        let initial_request = thread.read_with(cx, |thread, cx| {
2365            thread.to_completion_request(RequestKind::Chat, cx)
2366        });
2367
2368        // Make sure we don't have a stale file warning yet
2369        let has_stale_warning = initial_request.messages.iter().any(|msg| {
2370            msg.string_contents()
2371                .contains("These files changed since last read:")
2372        });
2373        assert!(
2374            !has_stale_warning,
2375            "Should not have stale buffer warning before buffer is modified"
2376        );
2377
2378        // Modify the buffer
2379        buffer.update(cx, |buffer, cx| {
2380            // Find a position at the end of line 1
2381            buffer.edit(
2382                [(1..1, "\n    println!(\"Added a new line\");\n")],
2383                None,
2384                cx,
2385            );
2386        });
2387
2388        // Insert another user message without context
2389        thread.update(cx, |thread, cx| {
2390            thread.insert_user_message("What does the code do now?", vec![], None, cx)
2391        });
2392
2393        // Create a new request and check for the stale buffer warning
2394        let new_request = thread.read_with(cx, |thread, cx| {
2395            thread.to_completion_request(RequestKind::Chat, cx)
2396        });
2397
2398        // We should have a stale file warning as the last message
2399        let last_message = new_request
2400            .messages
2401            .last()
2402            .expect("Request should have messages");
2403
2404        // The last message should be the stale buffer notification
2405        assert_eq!(last_message.role, Role::User);
2406
2407        // Check the exact content of the message
2408        let expected_content = "These files changed since last read:\n- code.rs\n";
2409        assert_eq!(
2410            last_message.string_contents(),
2411            expected_content,
2412            "Last message should be exactly the stale buffer notification"
2413        );
2414    }
2415
    // Installs a test `SettingsStore` as a global and then registers or initializes
    // the settings these tests rely on. The store is set before anything else runs.
    // NOTE(review): the remaining calls presumably need the store to exist; keep this order.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }
2430
2431    // Helper to create a test project with test files
2432    async fn create_test_project(
2433        cx: &mut TestAppContext,
2434        files: serde_json::Value,
2435    ) -> Entity<Project> {
2436        let fs = FakeFs::new(cx.executor());
2437        fs.insert_tree(path!("/test"), files).await;
2438        Project::test(fs, [path!("/test").as_ref()], cx).await
2439    }
2440
2441    async fn setup_test_environment(
2442        cx: &mut TestAppContext,
2443        project: Entity<Project>,
2444    ) -> (
2445        Entity<Workspace>,
2446        Entity<ThreadStore>,
2447        Entity<Thread>,
2448        Entity<ContextStore>,
2449    ) {
2450        let (workspace, cx) =
2451            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2452
2453        let thread_store = cx
2454            .update(|_, cx| {
2455                ThreadStore::load(
2456                    project.clone(),
2457                    cx.new(|_| ToolWorkingSet::default()),
2458                    Arc::new(PromptBuilder::new(None).unwrap()),
2459                    cx,
2460                )
2461            })
2462            .await;
2463
2464        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2465        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2466
2467        (workspace, thread_store, thread, context_store)
2468    }
2469
2470    async fn add_file_to_context(
2471        project: &Entity<Project>,
2472        context_store: &Entity<ContextStore>,
2473        path: &str,
2474        cx: &mut TestAppContext,
2475    ) -> Result<Entity<language::Buffer>> {
2476        let buffer_path = project
2477            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2478            .unwrap();
2479
2480        let buffer = project
2481            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2482            .await
2483            .unwrap();
2484
2485        context_store
2486            .update(cx, |store, cx| {
2487                store.add_file_from_buffer(buffer.clone(), cx)
2488            })
2489            .await?;
2490
2491        Ok(buffer)
2492    }
2493}