thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  19    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  20    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  21    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  22    ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
  23};
  24use project::Project;
  25use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  26use prompt_store::PromptBuilder;
  27use proto::Plan;
  28use schemars::JsonSchema;
  29use serde::{Deserialize, Serialize};
  30use settings::Settings;
  31use thiserror::Error;
  32use util::{ResultExt as _, TryFutureExt as _, post_inc};
  33use uuid::Uuid;
  34
  35use crate::context::{AssistantContext, ContextId, format_context_as_string};
  36use crate::thread_store::{
  37    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  38    SerializedToolUse, SharedProjectContext,
  39};
  40use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  41
  42#[derive(Debug, Clone, Copy)]
  43pub enum RequestKind {
  44    Chat,
  45    /// Used when summarizing a thread.
  46    Summarize,
  47}
  48
  49#[derive(
  50    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  51)]
  52pub struct ThreadId(Arc<str>);
  53
  54impl ThreadId {
  55    pub fn new() -> Self {
  56        Self(Uuid::new_v4().to_string().into())
  57    }
  58}
  59
  60impl std::fmt::Display for ThreadId {
  61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  62        write!(f, "{}", self.0)
  63    }
  64}
  65
  66impl From<&str> for ThreadId {
  67    fn from(value: &str) -> Self {
  68        Self(value.into())
  69    }
  70}
  71
  72#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  73pub struct MessageId(pub(crate) usize);
  74
  75impl MessageId {
  76    fn post_inc(&mut self) -> Self {
  77        Self(post_inc(&mut self.0))
  78    }
  79}
  80
  81/// A message in a [`Thread`].
  82#[derive(Debug, Clone)]
  83pub struct Message {
  84    pub id: MessageId,
  85    pub role: Role,
  86    pub segments: Vec<MessageSegment>,
  87    pub context: String,
  88}
  89
  90impl Message {
  91    /// Returns whether the message contains any meaningful text that should be displayed
  92    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  93    pub fn should_display_content(&self) -> bool {
  94        self.segments.iter().all(|segment| segment.should_display())
  95    }
  96
  97    pub fn push_thinking(&mut self, text: &str) {
  98        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  99            segment.push_str(text);
 100        } else {
 101            self.segments
 102                .push(MessageSegment::Thinking(text.to_string()));
 103        }
 104    }
 105
 106    pub fn push_text(&mut self, text: &str) {
 107        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 108            segment.push_str(text);
 109        } else {
 110            self.segments.push(MessageSegment::Text(text.to_string()));
 111        }
 112    }
 113
 114    pub fn to_string(&self) -> String {
 115        let mut result = String::new();
 116
 117        if !self.context.is_empty() {
 118            result.push_str(&self.context);
 119        }
 120
 121        for segment in &self.segments {
 122            match segment {
 123                MessageSegment::Text(text) => result.push_str(text),
 124                MessageSegment::Thinking(text) => {
 125                    result.push_str("<think>");
 126                    result.push_str(text);
 127                    result.push_str("</think>");
 128                }
 129            }
 130        }
 131
 132        result
 133    }
 134}
 135
 136#[derive(Debug, Clone, PartialEq, Eq)]
 137pub enum MessageSegment {
 138    Text(String),
 139    Thinking(String),
 140}
 141
 142impl MessageSegment {
 143    pub fn text_mut(&mut self) -> &mut String {
 144        match self {
 145            Self::Text(text) => text,
 146            Self::Thinking(text) => text,
 147        }
 148    }
 149
 150    pub fn should_display(&self) -> bool {
 151        // We add USING_TOOL_MARKER when making a request that includes tool uses
 152        // without non-whitespace text around them, and this can cause the model
 153        // to mimic the pattern, so we consider those segments not displayable.
 154        match self {
 155            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 156            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 157        }
 158    }
 159}
 160
 161#[derive(Debug, Clone, Serialize, Deserialize)]
 162pub struct ProjectSnapshot {
 163    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 164    pub unsaved_buffer_paths: Vec<String>,
 165    pub timestamp: DateTime<Utc>,
 166}
 167
 168#[derive(Debug, Clone, Serialize, Deserialize)]
 169pub struct WorktreeSnapshot {
 170    pub worktree_path: String,
 171    pub git_state: Option<GitState>,
 172}
 173
 174#[derive(Debug, Clone, Serialize, Deserialize)]
 175pub struct GitState {
 176    pub remote_url: Option<String>,
 177    pub head_sha: Option<String>,
 178    pub current_branch: Option<String>,
 179    pub diff: Option<String>,
 180}
 181
 182#[derive(Clone)]
 183pub struct ThreadCheckpoint {
 184    message_id: MessageId,
 185    git_checkpoint: GitStoreCheckpoint,
 186}
 187
 188#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 189pub enum ThreadFeedback {
 190    Positive,
 191    Negative,
 192}
 193
 194pub enum LastRestoreCheckpoint {
 195    Pending {
 196        message_id: MessageId,
 197    },
 198    Error {
 199        message_id: MessageId,
 200        error: String,
 201    },
 202}
 203
 204impl LastRestoreCheckpoint {
 205    pub fn message_id(&self) -> MessageId {
 206        match self {
 207            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 208            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 209        }
 210    }
 211}
 212
 213#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 214pub enum DetailedSummaryState {
 215    #[default]
 216    NotGenerated,
 217    Generating {
 218        message_id: MessageId,
 219    },
 220    Generated {
 221        text: SharedString,
 222        message_id: MessageId,
 223    },
 224}
 225
 226#[derive(Default)]
 227pub struct TotalTokenUsage {
 228    pub total: usize,
 229    pub max: usize,
 230    pub ratio: TokenUsageRatio,
 231}
 232
 233#[derive(Debug, Default, PartialEq, Eq)]
 234pub enum TokenUsageRatio {
 235    #[default]
 236    Normal,
 237    Warning,
 238    Exceeded,
 239}
 240
 241/// A thread of conversation with the LLM.
 242pub struct Thread {
 243    id: ThreadId,
 244    updated_at: DateTime<Utc>,
 245    summary: Option<SharedString>,
 246    pending_summary: Task<Option<()>>,
 247    detailed_summary_state: DetailedSummaryState,
 248    messages: Vec<Message>,
 249    next_message_id: MessageId,
 250    context: BTreeMap<ContextId, AssistantContext>,
 251    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 252    project_context: SharedProjectContext,
 253    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 254    completion_count: usize,
 255    pending_completions: Vec<PendingCompletion>,
 256    project: Entity<Project>,
 257    prompt_builder: Arc<PromptBuilder>,
 258    tools: Entity<ToolWorkingSet>,
 259    tool_use: ToolUseState,
 260    action_log: Entity<ActionLog>,
 261    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 262    pending_checkpoint: Option<ThreadCheckpoint>,
 263    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 264    cumulative_token_usage: TokenUsage,
 265    exceeded_window_error: Option<ExceededWindowError>,
 266    feedback: Option<ThreadFeedback>,
 267    message_feedback: HashMap<MessageId, ThreadFeedback>,
 268    last_auto_capture_at: Option<Instant>,
 269}
 270
 271#[derive(Debug, Clone, Serialize, Deserialize)]
 272pub struct ExceededWindowError {
 273    /// Model used when last message exceeded context window
 274    model_id: LanguageModelId,
 275    /// Token count including last message
 276    token_count: usize,
 277}
 278
 279impl Thread {
 280    pub fn new(
 281        project: Entity<Project>,
 282        tools: Entity<ToolWorkingSet>,
 283        prompt_builder: Arc<PromptBuilder>,
 284        system_prompt: SharedProjectContext,
 285        cx: &mut Context<Self>,
 286    ) -> Self {
 287        Self {
 288            id: ThreadId::new(),
 289            updated_at: Utc::now(),
 290            summary: None,
 291            pending_summary: Task::ready(None),
 292            detailed_summary_state: DetailedSummaryState::NotGenerated,
 293            messages: Vec::new(),
 294            next_message_id: MessageId(0),
 295            context: BTreeMap::default(),
 296            context_by_message: HashMap::default(),
 297            project_context: system_prompt,
 298            checkpoints_by_message: HashMap::default(),
 299            completion_count: 0,
 300            pending_completions: Vec::new(),
 301            project: project.clone(),
 302            prompt_builder,
 303            tools: tools.clone(),
 304            last_restore_checkpoint: None,
 305            pending_checkpoint: None,
 306            tool_use: ToolUseState::new(tools.clone()),
 307            action_log: cx.new(|_| ActionLog::new(project.clone())),
 308            initial_project_snapshot: {
 309                let project_snapshot = Self::project_snapshot(project, cx);
 310                cx.foreground_executor()
 311                    .spawn(async move { Some(project_snapshot.await) })
 312                    .shared()
 313            },
 314            cumulative_token_usage: TokenUsage::default(),
 315            exceeded_window_error: None,
 316            feedback: None,
 317            message_feedback: HashMap::default(),
 318            last_auto_capture_at: None,
 319        }
 320    }
 321
 322    pub fn deserialize(
 323        id: ThreadId,
 324        serialized: SerializedThread,
 325        project: Entity<Project>,
 326        tools: Entity<ToolWorkingSet>,
 327        prompt_builder: Arc<PromptBuilder>,
 328        project_context: SharedProjectContext,
 329        cx: &mut Context<Self>,
 330    ) -> Self {
 331        let next_message_id = MessageId(
 332            serialized
 333                .messages
 334                .last()
 335                .map(|message| message.id.0 + 1)
 336                .unwrap_or(0),
 337        );
 338        let tool_use =
 339            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 340
 341        Self {
 342            id,
 343            updated_at: serialized.updated_at,
 344            summary: Some(serialized.summary),
 345            pending_summary: Task::ready(None),
 346            detailed_summary_state: serialized.detailed_summary_state,
 347            messages: serialized
 348                .messages
 349                .into_iter()
 350                .map(|message| Message {
 351                    id: message.id,
 352                    role: message.role,
 353                    segments: message
 354                        .segments
 355                        .into_iter()
 356                        .map(|segment| match segment {
 357                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 358                            SerializedMessageSegment::Thinking { text } => {
 359                                MessageSegment::Thinking(text)
 360                            }
 361                        })
 362                        .collect(),
 363                    context: message.context,
 364                })
 365                .collect(),
 366            next_message_id,
 367            context: BTreeMap::default(),
 368            context_by_message: HashMap::default(),
 369            project_context,
 370            checkpoints_by_message: HashMap::default(),
 371            completion_count: 0,
 372            pending_completions: Vec::new(),
 373            last_restore_checkpoint: None,
 374            pending_checkpoint: None,
 375            project: project.clone(),
 376            prompt_builder,
 377            tools,
 378            tool_use,
 379            action_log: cx.new(|_| ActionLog::new(project)),
 380            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 381            cumulative_token_usage: serialized.cumulative_token_usage,
 382            exceeded_window_error: None,
 383            feedback: None,
 384            message_feedback: HashMap::default(),
 385            last_auto_capture_at: None,
 386        }
 387    }
 388
 389    pub fn id(&self) -> &ThreadId {
 390        &self.id
 391    }
 392
 393    pub fn is_empty(&self) -> bool {
 394        self.messages.is_empty()
 395    }
 396
 397    pub fn updated_at(&self) -> DateTime<Utc> {
 398        self.updated_at
 399    }
 400
 401    pub fn touch_updated_at(&mut self) {
 402        self.updated_at = Utc::now();
 403    }
 404
 405    pub fn summary(&self) -> Option<SharedString> {
 406        self.summary.clone()
 407    }
 408
 409    pub fn project_context(&self) -> SharedProjectContext {
 410        self.project_context.clone()
 411    }
 412
 413    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 414
 415    pub fn summary_or_default(&self) -> SharedString {
 416        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 417    }
 418
 419    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 420        let Some(current_summary) = &self.summary else {
 421            // Don't allow setting summary until generated
 422            return;
 423        };
 424
 425        let mut new_summary = new_summary.into();
 426
 427        if new_summary.is_empty() {
 428            new_summary = Self::DEFAULT_SUMMARY;
 429        }
 430
 431        if current_summary != &new_summary {
 432            self.summary = Some(new_summary);
 433            cx.emit(ThreadEvent::SummaryChanged);
 434        }
 435    }
 436
 437    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 438        self.latest_detailed_summary()
 439            .unwrap_or_else(|| self.text().into())
 440    }
 441
 442    fn latest_detailed_summary(&self) -> Option<SharedString> {
 443        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 444            Some(text.clone())
 445        } else {
 446            None
 447        }
 448    }
 449
 450    pub fn message(&self, id: MessageId) -> Option<&Message> {
 451        self.messages.iter().find(|message| message.id == id)
 452    }
 453
 454    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 455        self.messages.iter()
 456    }
 457
 458    pub fn is_generating(&self) -> bool {
 459        !self.pending_completions.is_empty() || !self.all_tools_finished()
 460    }
 461
 462    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
 463        &self.tools
 464    }
 465
 466    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 467        self.tool_use
 468            .pending_tool_uses()
 469            .into_iter()
 470            .find(|tool_use| &tool_use.id == id)
 471    }
 472
 473    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 474        self.tool_use
 475            .pending_tool_uses()
 476            .into_iter()
 477            .filter(|tool_use| tool_use.status.needs_confirmation())
 478    }
 479
 480    pub fn has_pending_tool_uses(&self) -> bool {
 481        !self.tool_use.pending_tool_uses().is_empty()
 482    }
 483
 484    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 485        self.checkpoints_by_message.get(&id).cloned()
 486    }
 487
 488    pub fn restore_checkpoint(
 489        &mut self,
 490        checkpoint: ThreadCheckpoint,
 491        cx: &mut Context<Self>,
 492    ) -> Task<Result<()>> {
 493        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 494            message_id: checkpoint.message_id,
 495        });
 496        cx.emit(ThreadEvent::CheckpointChanged);
 497        cx.notify();
 498
 499        let git_store = self.project().read(cx).git_store().clone();
 500        let restore = git_store.update(cx, |git_store, cx| {
 501            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 502        });
 503
 504        cx.spawn(async move |this, cx| {
 505            let result = restore.await;
 506            this.update(cx, |this, cx| {
 507                if let Err(err) = result.as_ref() {
 508                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 509                        message_id: checkpoint.message_id,
 510                        error: err.to_string(),
 511                    });
 512                } else {
 513                    this.truncate(checkpoint.message_id, cx);
 514                    this.last_restore_checkpoint = None;
 515                }
 516                this.pending_checkpoint = None;
 517                cx.emit(ThreadEvent::CheckpointChanged);
 518                cx.notify();
 519            })?;
 520            result
 521        })
 522    }
 523
 524    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 525        let pending_checkpoint = if self.is_generating() {
 526            return;
 527        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 528            checkpoint
 529        } else {
 530            return;
 531        };
 532
 533        let git_store = self.project.read(cx).git_store().clone();
 534        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 535        cx.spawn(async move |this, cx| match final_checkpoint.await {
 536            Ok(final_checkpoint) => {
 537                let equal = git_store
 538                    .update(cx, |store, cx| {
 539                        store.compare_checkpoints(
 540                            pending_checkpoint.git_checkpoint.clone(),
 541                            final_checkpoint.clone(),
 542                            cx,
 543                        )
 544                    })?
 545                    .await
 546                    .unwrap_or(false);
 547
 548                if equal {
 549                    git_store
 550                        .update(cx, |store, cx| {
 551                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 552                        })?
 553                        .detach();
 554                } else {
 555                    this.update(cx, |this, cx| {
 556                        this.insert_checkpoint(pending_checkpoint, cx)
 557                    })?;
 558                }
 559
 560                git_store
 561                    .update(cx, |store, cx| {
 562                        store.delete_checkpoint(final_checkpoint, cx)
 563                    })?
 564                    .detach();
 565
 566                Ok(())
 567            }
 568            Err(_) => this.update(cx, |this, cx| {
 569                this.insert_checkpoint(pending_checkpoint, cx)
 570            }),
 571        })
 572        .detach();
 573    }
 574
 575    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 576        self.checkpoints_by_message
 577            .insert(checkpoint.message_id, checkpoint);
 578        cx.emit(ThreadEvent::CheckpointChanged);
 579        cx.notify();
 580    }
 581
 582    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 583        self.last_restore_checkpoint.as_ref()
 584    }
 585
 586    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 587        let Some(message_ix) = self
 588            .messages
 589            .iter()
 590            .rposition(|message| message.id == message_id)
 591        else {
 592            return;
 593        };
 594        for deleted_message in self.messages.drain(message_ix..) {
 595            self.context_by_message.remove(&deleted_message.id);
 596            self.checkpoints_by_message.remove(&deleted_message.id);
 597        }
 598        cx.notify();
 599    }
 600
 601    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 602        self.context_by_message
 603            .get(&id)
 604            .into_iter()
 605            .flat_map(|context| {
 606                context
 607                    .iter()
 608                    .filter_map(|context_id| self.context.get(&context_id))
 609            })
 610    }
 611
 612    /// Returns whether all of the tool uses have finished running.
 613    pub fn all_tools_finished(&self) -> bool {
 614        // If the only pending tool uses left are the ones with errors, then
 615        // that means that we've finished running all of the pending tools.
 616        self.tool_use
 617            .pending_tool_uses()
 618            .iter()
 619            .all(|tool_use| tool_use.status.is_error())
 620    }
 621
 622    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 623        self.tool_use.tool_uses_for_message(id, cx)
 624    }
 625
 626    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 627        self.tool_use.tool_results_for_message(id)
 628    }
 629
 630    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 631        self.tool_use.tool_result(id)
 632    }
 633
 634    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 635        Some(&self.tool_use.tool_result(id)?.content)
 636    }
 637
 638    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
 639        self.tool_use.tool_result_card(id).cloned()
 640    }
 641
 642    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 643        self.tool_use.message_has_tool_results(message_id)
 644    }
 645
 646    pub fn insert_user_message(
 647        &mut self,
 648        text: impl Into<String>,
 649        context: Vec<AssistantContext>,
 650        git_checkpoint: Option<GitStoreCheckpoint>,
 651        cx: &mut Context<Self>,
 652    ) -> MessageId {
 653        let text = text.into();
 654
 655        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 656
 657        // Filter out contexts that have already been included in previous messages
 658        let new_context: Vec<_> = context
 659            .into_iter()
 660            .filter(|ctx| !self.context.contains_key(&ctx.id()))
 661            .collect();
 662
 663        if !new_context.is_empty() {
 664            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 665                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 666                    message.context = context_string;
 667                }
 668            }
 669
 670            self.action_log.update(cx, |log, cx| {
 671                // Track all buffers added as context
 672                for ctx in &new_context {
 673                    match ctx {
 674                        AssistantContext::File(file_ctx) => {
 675                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 676                        }
 677                        AssistantContext::Directory(dir_ctx) => {
 678                            for context_buffer in &dir_ctx.context_buffers {
 679                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 680                            }
 681                        }
 682                        AssistantContext::Symbol(symbol_ctx) => {
 683                            log.buffer_added_as_context(
 684                                symbol_ctx.context_symbol.buffer.clone(),
 685                                cx,
 686                            );
 687                        }
 688                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 689                    }
 690                }
 691            });
 692        }
 693
 694        let context_ids = new_context
 695            .iter()
 696            .map(|context| context.id())
 697            .collect::<Vec<_>>();
 698        self.context.extend(
 699            new_context
 700                .into_iter()
 701                .map(|context| (context.id(), context)),
 702        );
 703        self.context_by_message.insert(message_id, context_ids);
 704
 705        if let Some(git_checkpoint) = git_checkpoint {
 706            self.pending_checkpoint = Some(ThreadCheckpoint {
 707                message_id,
 708                git_checkpoint,
 709            });
 710        }
 711
 712        self.auto_capture_telemetry(cx);
 713
 714        message_id
 715    }
 716
 717    pub fn insert_message(
 718        &mut self,
 719        role: Role,
 720        segments: Vec<MessageSegment>,
 721        cx: &mut Context<Self>,
 722    ) -> MessageId {
 723        let id = self.next_message_id.post_inc();
 724        self.messages.push(Message {
 725            id,
 726            role,
 727            segments,
 728            context: String::new(),
 729        });
 730        self.touch_updated_at();
 731        cx.emit(ThreadEvent::MessageAdded(id));
 732        id
 733    }
 734
 735    pub fn edit_message(
 736        &mut self,
 737        id: MessageId,
 738        new_role: Role,
 739        new_segments: Vec<MessageSegment>,
 740        cx: &mut Context<Self>,
 741    ) -> bool {
 742        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 743            return false;
 744        };
 745        message.role = new_role;
 746        message.segments = new_segments;
 747        self.touch_updated_at();
 748        cx.emit(ThreadEvent::MessageEdited(id));
 749        true
 750    }
 751
 752    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 753        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 754            return false;
 755        };
 756        self.messages.remove(index);
 757        self.context_by_message.remove(&id);
 758        self.touch_updated_at();
 759        cx.emit(ThreadEvent::MessageDeleted(id));
 760        true
 761    }
 762
 763    /// Returns the representation of this [`Thread`] in a textual form.
 764    ///
 765    /// This is the representation we use when attaching a thread as context to another thread.
 766    pub fn text(&self) -> String {
 767        let mut text = String::new();
 768
 769        for message in &self.messages {
 770            text.push_str(match message.role {
 771                language_model::Role::User => "User:",
 772                language_model::Role::Assistant => "Assistant:",
 773                language_model::Role::System => "System:",
 774            });
 775            text.push('\n');
 776
 777            for segment in &message.segments {
 778                match segment {
 779                    MessageSegment::Text(content) => text.push_str(content),
 780                    MessageSegment::Thinking(content) => {
 781                        text.push_str(&format!("<think>{}</think>", content))
 782                    }
 783                }
 784            }
 785            text.push('\n');
 786        }
 787
 788        text
 789    }
 790
 791    /// Serializes this thread into a format for storage or telemetry.
 792    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 793        let initial_project_snapshot = self.initial_project_snapshot.clone();
 794        cx.spawn(async move |this, cx| {
 795            let initial_project_snapshot = initial_project_snapshot.await;
 796            this.read_with(cx, |this, cx| SerializedThread {
 797                version: SerializedThread::VERSION.to_string(),
 798                summary: this.summary_or_default(),
 799                updated_at: this.updated_at(),
 800                messages: this
 801                    .messages()
 802                    .map(|message| SerializedMessage {
 803                        id: message.id,
 804                        role: message.role,
 805                        segments: message
 806                            .segments
 807                            .iter()
 808                            .map(|segment| match segment {
 809                                MessageSegment::Text(text) => {
 810                                    SerializedMessageSegment::Text { text: text.clone() }
 811                                }
 812                                MessageSegment::Thinking(text) => {
 813                                    SerializedMessageSegment::Thinking { text: text.clone() }
 814                                }
 815                            })
 816                            .collect(),
 817                        tool_uses: this
 818                            .tool_uses_for_message(message.id, cx)
 819                            .into_iter()
 820                            .map(|tool_use| SerializedToolUse {
 821                                id: tool_use.id,
 822                                name: tool_use.name,
 823                                input: tool_use.input,
 824                            })
 825                            .collect(),
 826                        tool_results: this
 827                            .tool_results_for_message(message.id)
 828                            .into_iter()
 829                            .map(|tool_result| SerializedToolResult {
 830                                tool_use_id: tool_result.tool_use_id.clone(),
 831                                is_error: tool_result.is_error,
 832                                content: tool_result.content.clone(),
 833                            })
 834                            .collect(),
 835                        context: message.context.clone(),
 836                    })
 837                    .collect(),
 838                initial_project_snapshot,
 839                cumulative_token_usage: this.cumulative_token_usage,
 840                detailed_summary_state: this.detailed_summary_state.clone(),
 841                exceeded_window_error: this.exceeded_window_error.clone(),
 842            })
 843        })
 844    }
 845
 846    pub fn send_to_model(
 847        &mut self,
 848        model: Arc<dyn LanguageModel>,
 849        request_kind: RequestKind,
 850        cx: &mut Context<Self>,
 851    ) {
 852        let mut request = self.to_completion_request(request_kind, cx);
 853        if model.supports_tools() {
 854            request.tools = {
 855                let mut tools = Vec::new();
 856                tools.extend(
 857                    self.tools()
 858                        .read(cx)
 859                        .enabled_tools(cx)
 860                        .into_iter()
 861                        .filter_map(|tool| {
 862                            // Skip tools that cannot be supported
 863                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 864                            Some(LanguageModelRequestTool {
 865                                name: tool.name(),
 866                                description: tool.description(),
 867                                input_schema,
 868                            })
 869                        }),
 870                );
 871
 872                tools
 873            };
 874        }
 875
 876        self.stream_completion(request, model, cx);
 877    }
 878
 879    pub fn used_tools_since_last_user_message(&self) -> bool {
 880        for message in self.messages.iter().rev() {
 881            if self.tool_use.message_has_tool_results(message.id) {
 882                return true;
 883            } else if message.role == Role::User {
 884                return false;
 885            }
 886        }
 887
 888        false
 889    }
 890
 891    pub fn to_completion_request(
 892        &self,
 893        request_kind: RequestKind,
 894        cx: &App,
 895    ) -> LanguageModelRequest {
 896        let mut request = LanguageModelRequest {
 897            messages: vec![],
 898            tools: Vec::new(),
 899            stop: Vec::new(),
 900            temperature: None,
 901        };
 902
 903        if let Some(project_context) = self.project_context.borrow().as_ref() {
 904            if let Some(system_prompt) = self
 905                .prompt_builder
 906                .generate_assistant_system_prompt(project_context)
 907                .context("failed to generate assistant system prompt")
 908                .log_err()
 909            {
 910                request.messages.push(LanguageModelRequestMessage {
 911                    role: Role::System,
 912                    content: vec![MessageContent::Text(system_prompt)],
 913                    cache: true,
 914                });
 915            }
 916        } else {
 917            log::error!("project_context not set.")
 918        }
 919
 920        for message in &self.messages {
 921            let mut request_message = LanguageModelRequestMessage {
 922                role: message.role,
 923                content: Vec::new(),
 924                cache: false,
 925            };
 926
 927            match request_kind {
 928                RequestKind::Chat => {
 929                    self.tool_use
 930                        .attach_tool_results(message.id, &mut request_message);
 931                }
 932                RequestKind::Summarize => {
 933                    // We don't care about tool use during summarization.
 934                    if self.tool_use.message_has_tool_results(message.id) {
 935                        continue;
 936                    }
 937                }
 938            }
 939
 940            if !message.segments.is_empty() {
 941                request_message
 942                    .content
 943                    .push(MessageContent::Text(message.to_string()));
 944            }
 945
 946            match request_kind {
 947                RequestKind::Chat => {
 948                    self.tool_use
 949                        .attach_tool_uses(message.id, &mut request_message);
 950                }
 951                RequestKind::Summarize => {
 952                    // We don't care about tool use during summarization.
 953                }
 954            };
 955
 956            request.messages.push(request_message);
 957        }
 958
 959        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 960        if let Some(last) = request.messages.last_mut() {
 961            last.cache = true;
 962        }
 963
 964        self.attached_tracked_files_state(&mut request.messages, cx);
 965
 966        request
 967    }
 968
 969    fn attached_tracked_files_state(
 970        &self,
 971        messages: &mut Vec<LanguageModelRequestMessage>,
 972        cx: &App,
 973    ) {
 974        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 975
 976        let mut stale_message = String::new();
 977
 978        let action_log = self.action_log.read(cx);
 979
 980        for stale_file in action_log.stale_buffers(cx) {
 981            let Some(file) = stale_file.read(cx).file() else {
 982                continue;
 983            };
 984
 985            if stale_message.is_empty() {
 986                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
 987            }
 988
 989            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 990        }
 991
 992        let mut content = Vec::with_capacity(2);
 993
 994        if !stale_message.is_empty() {
 995            content.push(stale_message.into());
 996        }
 997
 998        if action_log.has_edited_files_since_project_diagnostics_check() {
 999            content.push(
1000                "\n\nWhen you're done making changes, make sure to check project diagnostics \
1001                and fix all errors AND warnings you introduced! \
1002                DO NOT mention you're going to do this until you're done."
1003                    .into(),
1004            );
1005        }
1006
1007        if !content.is_empty() {
1008            let context_message = LanguageModelRequestMessage {
1009                role: Role::User,
1010                content,
1011                cache: false,
1012            };
1013
1014            messages.push(context_message);
1015        }
1016    }
1017
1018    pub fn stream_completion(
1019        &mut self,
1020        request: LanguageModelRequest,
1021        model: Arc<dyn LanguageModel>,
1022        cx: &mut Context<Self>,
1023    ) {
1024        let pending_completion_id = post_inc(&mut self.completion_count);
1025
1026        let task = cx.spawn(async move |thread, cx| {
1027            let stream = model.stream_completion(request, &cx);
1028            let initial_token_usage =
1029                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1030            let stream_completion = async {
1031                let mut events = stream.await?;
1032                let mut stop_reason = StopReason::EndTurn;
1033                let mut current_token_usage = TokenUsage::default();
1034
1035                while let Some(event) = events.next().await {
1036                    let event = event?;
1037
1038                    thread.update(cx, |thread, cx| {
1039                        match event {
1040                            LanguageModelCompletionEvent::StartMessage { .. } => {
1041                                thread.insert_message(
1042                                    Role::Assistant,
1043                                    vec![MessageSegment::Text(String::new())],
1044                                    cx,
1045                                );
1046                            }
1047                            LanguageModelCompletionEvent::Stop(reason) => {
1048                                stop_reason = reason;
1049                            }
1050                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1051                                thread.cumulative_token_usage = thread.cumulative_token_usage
1052                                    + token_usage
1053                                    - current_token_usage;
1054                                current_token_usage = token_usage;
1055                            }
1056                            LanguageModelCompletionEvent::Text(chunk) => {
1057                                if let Some(last_message) = thread.messages.last_mut() {
1058                                    if last_message.role == Role::Assistant {
1059                                        last_message.push_text(&chunk);
1060                                        cx.emit(ThreadEvent::StreamedAssistantText(
1061                                            last_message.id,
1062                                            chunk,
1063                                        ));
1064                                    } else {
1065                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1066                                        // of a new Assistant response.
1067                                        //
1068                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1069                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1070                                        thread.insert_message(
1071                                            Role::Assistant,
1072                                            vec![MessageSegment::Text(chunk.to_string())],
1073                                            cx,
1074                                        );
1075                                    };
1076                                }
1077                            }
1078                            LanguageModelCompletionEvent::Thinking(chunk) => {
1079                                if let Some(last_message) = thread.messages.last_mut() {
1080                                    if last_message.role == Role::Assistant {
1081                                        last_message.push_thinking(&chunk);
1082                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1083                                            last_message.id,
1084                                            chunk,
1085                                        ));
1086                                    } else {
1087                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1088                                        // of a new Assistant response.
1089                                        //
1090                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1091                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1092                                        thread.insert_message(
1093                                            Role::Assistant,
1094                                            vec![MessageSegment::Thinking(chunk.to_string())],
1095                                            cx,
1096                                        );
1097                                    };
1098                                }
1099                            }
1100                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1101                                let last_assistant_message_id = thread
1102                                    .messages
1103                                    .iter_mut()
1104                                    .rfind(|message| message.role == Role::Assistant)
1105                                    .map(|message| message.id)
1106                                    .unwrap_or_else(|| {
1107                                        thread.insert_message(Role::Assistant, vec![], cx)
1108                                    });
1109
1110                                thread.tool_use.request_tool_use(
1111                                    last_assistant_message_id,
1112                                    tool_use,
1113                                    cx,
1114                                );
1115                            }
1116                        }
1117
1118                        thread.touch_updated_at();
1119                        cx.emit(ThreadEvent::StreamedCompletion);
1120                        cx.notify();
1121
1122                        thread.auto_capture_telemetry(cx);
1123                    })?;
1124
1125                    smol::future::yield_now().await;
1126                }
1127
1128                thread.update(cx, |thread, cx| {
1129                    thread
1130                        .pending_completions
1131                        .retain(|completion| completion.id != pending_completion_id);
1132
1133                    if thread.summary.is_none() && thread.messages.len() >= 2 {
1134                        thread.summarize(cx);
1135                    }
1136                })?;
1137
1138                anyhow::Ok(stop_reason)
1139            };
1140
1141            let result = stream_completion.await;
1142
1143            thread
1144                .update(cx, |thread, cx| {
1145                    thread.finalize_pending_checkpoint(cx);
1146                    match result.as_ref() {
1147                        Ok(stop_reason) => match stop_reason {
1148                            StopReason::ToolUse => {
1149                                let tool_uses = thread.use_pending_tools(cx);
1150                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1151                            }
1152                            StopReason::EndTurn => {}
1153                            StopReason::MaxTokens => {}
1154                        },
1155                        Err(error) => {
1156                            if error.is::<PaymentRequiredError>() {
1157                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1158                            } else if error.is::<MaxMonthlySpendReachedError>() {
1159                                cx.emit(ThreadEvent::ShowError(
1160                                    ThreadError::MaxMonthlySpendReached,
1161                                ));
1162                            } else if let Some(error) =
1163                                error.downcast_ref::<ModelRequestLimitReachedError>()
1164                            {
1165                                cx.emit(ThreadEvent::ShowError(
1166                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1167                                ));
1168                            } else if let Some(known_error) =
1169                                error.downcast_ref::<LanguageModelKnownError>()
1170                            {
1171                                match known_error {
1172                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1173                                        tokens,
1174                                    } => {
1175                                        thread.exceeded_window_error = Some(ExceededWindowError {
1176                                            model_id: model.id(),
1177                                            token_count: *tokens,
1178                                        });
1179                                        cx.notify();
1180                                    }
1181                                }
1182                            } else {
1183                                let error_message = error
1184                                    .chain()
1185                                    .map(|err| err.to_string())
1186                                    .collect::<Vec<_>>()
1187                                    .join("\n");
1188                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1189                                    header: "Error interacting with language model".into(),
1190                                    message: SharedString::from(error_message.clone()),
1191                                }));
1192                            }
1193
1194                            thread.cancel_last_completion(cx);
1195                        }
1196                    }
1197                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1198
1199                    thread.auto_capture_telemetry(cx);
1200
1201                    if let Ok(initial_usage) = initial_token_usage {
1202                        let usage = thread.cumulative_token_usage - initial_usage;
1203
1204                        telemetry::event!(
1205                            "Assistant Thread Completion",
1206                            thread_id = thread.id().to_string(),
1207                            model = model.telemetry_id(),
1208                            model_provider = model.provider_id().to_string(),
1209                            input_tokens = usage.input_tokens,
1210                            output_tokens = usage.output_tokens,
1211                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1212                            cache_read_input_tokens = usage.cache_read_input_tokens,
1213                        );
1214                    }
1215                })
1216                .ok();
1217        });
1218
1219        self.pending_completions.push(PendingCompletion {
1220            id: pending_completion_id,
1221            _task: task,
1222        });
1223    }
1224
1225    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1226        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1227            return;
1228        };
1229
1230        if !model.provider.is_authenticated(cx) {
1231            return;
1232        }
1233
1234        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1235        request.messages.push(LanguageModelRequestMessage {
1236            role: Role::User,
1237            content: vec![
1238                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1239                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1240                 If the conversation is about a specific subject, include it in the title. \
1241                 Be descriptive. DO NOT speak in the first person."
1242                    .into(),
1243            ],
1244            cache: false,
1245        });
1246
1247        self.pending_summary = cx.spawn(async move |this, cx| {
1248            async move {
1249                let stream = model.model.stream_completion_text(request, &cx);
1250                let mut messages = stream.await?;
1251
1252                let mut new_summary = String::new();
1253                while let Some(message) = messages.stream.next().await {
1254                    let text = message?;
1255                    let mut lines = text.lines();
1256                    new_summary.extend(lines.next());
1257
1258                    // Stop if the LLM generated multiple lines.
1259                    if lines.next().is_some() {
1260                        break;
1261                    }
1262                }
1263
1264                this.update(cx, |this, cx| {
1265                    if !new_summary.is_empty() {
1266                        this.summary = Some(new_summary.into());
1267                    }
1268
1269                    cx.emit(ThreadEvent::SummaryGenerated);
1270                })?;
1271
1272                anyhow::Ok(())
1273            }
1274            .log_err()
1275            .await
1276        });
1277    }
1278
1279    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1280        let last_message_id = self.messages.last().map(|message| message.id)?;
1281
1282        match &self.detailed_summary_state {
1283            DetailedSummaryState::Generating { message_id, .. }
1284            | DetailedSummaryState::Generated { message_id, .. }
1285                if *message_id == last_message_id =>
1286            {
1287                // Already up-to-date
1288                return None;
1289            }
1290            _ => {}
1291        }
1292
1293        let ConfiguredModel { model, provider } =
1294            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1295
1296        if !provider.is_authenticated(cx) {
1297            return None;
1298        }
1299
1300        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1301
1302        request.messages.push(LanguageModelRequestMessage {
1303            role: Role::User,
1304            content: vec![
1305                "Generate a detailed summary of this conversation. Include:\n\
1306                1. A brief overview of what was discussed\n\
1307                2. Key facts or information discovered\n\
1308                3. Outcomes or conclusions reached\n\
1309                4. Any action items or next steps if any\n\
1310                Format it in Markdown with headings and bullet points."
1311                    .into(),
1312            ],
1313            cache: false,
1314        });
1315
1316        let task = cx.spawn(async move |thread, cx| {
1317            let stream = model.stream_completion_text(request, &cx);
1318            let Some(mut messages) = stream.await.log_err() else {
1319                thread
1320                    .update(cx, |this, _cx| {
1321                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1322                    })
1323                    .log_err();
1324
1325                return;
1326            };
1327
1328            let mut new_detailed_summary = String::new();
1329
1330            while let Some(chunk) = messages.stream.next().await {
1331                if let Some(chunk) = chunk.log_err() {
1332                    new_detailed_summary.push_str(&chunk);
1333                }
1334            }
1335
1336            thread
1337                .update(cx, |this, _cx| {
1338                    this.detailed_summary_state = DetailedSummaryState::Generated {
1339                        text: new_detailed_summary.into(),
1340                        message_id: last_message_id,
1341                    };
1342                })
1343                .log_err();
1344        });
1345
1346        self.detailed_summary_state = DetailedSummaryState::Generating {
1347            message_id: last_message_id,
1348        };
1349
1350        Some(task)
1351    }
1352
1353    pub fn is_generating_detailed_summary(&self) -> bool {
1354        matches!(
1355            self.detailed_summary_state,
1356            DetailedSummaryState::Generating { .. }
1357        )
1358    }
1359
1360    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1361        self.auto_capture_telemetry(cx);
1362        let request = self.to_completion_request(RequestKind::Chat, cx);
1363        let messages = Arc::new(request.messages);
1364        let pending_tool_uses = self
1365            .tool_use
1366            .pending_tool_uses()
1367            .into_iter()
1368            .filter(|tool_use| tool_use.status.is_idle())
1369            .cloned()
1370            .collect::<Vec<_>>();
1371
1372        for tool_use in pending_tool_uses.iter() {
1373            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1374                if tool.needs_confirmation(&tool_use.input, cx)
1375                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1376                {
1377                    self.tool_use.confirm_tool_use(
1378                        tool_use.id.clone(),
1379                        tool_use.ui_text.clone(),
1380                        tool_use.input.clone(),
1381                        messages.clone(),
1382                        tool,
1383                    );
1384                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1385                } else {
1386                    self.run_tool(
1387                        tool_use.id.clone(),
1388                        tool_use.ui_text.clone(),
1389                        tool_use.input.clone(),
1390                        &messages,
1391                        tool,
1392                        cx,
1393                    );
1394                }
1395            }
1396        }
1397
1398        pending_tool_uses
1399    }
1400
1401    pub fn run_tool(
1402        &mut self,
1403        tool_use_id: LanguageModelToolUseId,
1404        ui_text: impl Into<SharedString>,
1405        input: serde_json::Value,
1406        messages: &[LanguageModelRequestMessage],
1407        tool: Arc<dyn Tool>,
1408        cx: &mut Context<Thread>,
1409    ) {
1410        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1411        self.tool_use
1412            .run_pending_tool(tool_use_id, ui_text.into(), task);
1413    }
1414
1415    fn spawn_tool_use(
1416        &mut self,
1417        tool_use_id: LanguageModelToolUseId,
1418        messages: &[LanguageModelRequestMessage],
1419        input: serde_json::Value,
1420        tool: Arc<dyn Tool>,
1421        cx: &mut Context<Thread>,
1422    ) -> Task<()> {
1423        let tool_name: Arc<str> = tool.name().into();
1424
1425        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1426            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1427        } else {
1428            tool.run(
1429                input,
1430                messages,
1431                self.project.clone(),
1432                self.action_log.clone(),
1433                cx,
1434            )
1435        };
1436
1437        // Store the card separately if it exists
1438        if let Some(card) = tool_result.card.clone() {
1439            self.tool_use
1440                .insert_tool_result_card(tool_use_id.clone(), card);
1441        }
1442
1443        cx.spawn({
1444            async move |thread: WeakEntity<Thread>, cx| {
1445                let output = tool_result.output.await;
1446
1447                thread
1448                    .update(cx, |thread, cx| {
1449                        let pending_tool_use = thread.tool_use.insert_tool_output(
1450                            tool_use_id.clone(),
1451                            tool_name,
1452                            output,
1453                            cx,
1454                        );
1455                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1456                    })
1457                    .ok();
1458            }
1459        })
1460    }
1461
1462    fn tool_finished(
1463        &mut self,
1464        tool_use_id: LanguageModelToolUseId,
1465        pending_tool_use: Option<PendingToolUse>,
1466        canceled: bool,
1467        cx: &mut Context<Self>,
1468    ) {
1469        if self.all_tools_finished() {
1470            let model_registry = LanguageModelRegistry::read_global(cx);
1471            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1472                self.attach_tool_results(cx);
1473                if !canceled {
1474                    self.send_to_model(model, RequestKind::Chat, cx);
1475                }
1476            }
1477        }
1478
1479        cx.emit(ThreadEvent::ToolFinished {
1480            tool_use_id,
1481            pending_tool_use,
1482        });
1483    }
1484
1485    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1486        // Insert a user message to contain the tool results.
1487        self.insert_user_message(
1488            // TODO: Sending up a user message without any content results in the model sending back
1489            // responses that also don't have any content. We currently don't handle this case well,
1490            // so for now we provide some text to keep the model on track.
1491            "Here are the tool results.",
1492            Vec::new(),
1493            None,
1494            cx,
1495        );
1496    }
1497
1498    /// Cancels the last pending completion, if there are any pending.
1499    ///
1500    /// Returns whether a completion was canceled.
1501    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1502        let canceled = if self.pending_completions.pop().is_some() {
1503            true
1504        } else {
1505            let mut canceled = false;
1506            for pending_tool_use in self.tool_use.cancel_pending() {
1507                canceled = true;
1508                self.tool_finished(
1509                    pending_tool_use.id.clone(),
1510                    Some(pending_tool_use),
1511                    true,
1512                    cx,
1513                );
1514            }
1515            canceled
1516        };
1517        self.finalize_pending_checkpoint(cx);
1518        canceled
1519    }
1520
1521    pub fn feedback(&self) -> Option<ThreadFeedback> {
1522        self.feedback
1523    }
1524
1525    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1526        self.message_feedback.get(&message_id).copied()
1527    }
1528
1529    pub fn report_message_feedback(
1530        &mut self,
1531        message_id: MessageId,
1532        feedback: ThreadFeedback,
1533        cx: &mut Context<Self>,
1534    ) -> Task<Result<()>> {
1535        if self.message_feedback.get(&message_id) == Some(&feedback) {
1536            return Task::ready(Ok(()));
1537        }
1538
1539        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1540        let serialized_thread = self.serialize(cx);
1541        let thread_id = self.id().clone();
1542        let client = self.project.read(cx).client();
1543
1544        let enabled_tool_names: Vec<String> = self
1545            .tools()
1546            .read(cx)
1547            .enabled_tools(cx)
1548            .iter()
1549            .map(|tool| tool.name().to_string())
1550            .collect();
1551
1552        self.message_feedback.insert(message_id, feedback);
1553
1554        cx.notify();
1555
1556        let message_content = self
1557            .message(message_id)
1558            .map(|msg| msg.to_string())
1559            .unwrap_or_default();
1560
1561        cx.background_spawn(async move {
1562            let final_project_snapshot = final_project_snapshot.await;
1563            let serialized_thread = serialized_thread.await?;
1564            let thread_data =
1565                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1566
1567            let rating = match feedback {
1568                ThreadFeedback::Positive => "positive",
1569                ThreadFeedback::Negative => "negative",
1570            };
1571            telemetry::event!(
1572                "Assistant Thread Rated",
1573                rating,
1574                thread_id,
1575                enabled_tool_names,
1576                message_id = message_id.0,
1577                message_content,
1578                thread_data,
1579                final_project_snapshot
1580            );
1581            client.telemetry().flush_events();
1582
1583            Ok(())
1584        })
1585    }
1586
1587    pub fn report_feedback(
1588        &mut self,
1589        feedback: ThreadFeedback,
1590        cx: &mut Context<Self>,
1591    ) -> Task<Result<()>> {
1592        let last_assistant_message_id = self
1593            .messages
1594            .iter()
1595            .rev()
1596            .find(|msg| msg.role == Role::Assistant)
1597            .map(|msg| msg.id);
1598
1599        if let Some(message_id) = last_assistant_message_id {
1600            self.report_message_feedback(message_id, feedback, cx)
1601        } else {
1602            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1603            let serialized_thread = self.serialize(cx);
1604            let thread_id = self.id().clone();
1605            let client = self.project.read(cx).client();
1606            self.feedback = Some(feedback);
1607            cx.notify();
1608
1609            cx.background_spawn(async move {
1610                let final_project_snapshot = final_project_snapshot.await;
1611                let serialized_thread = serialized_thread.await?;
1612                let thread_data = serde_json::to_value(serialized_thread)
1613                    .unwrap_or_else(|_| serde_json::Value::Null);
1614
1615                let rating = match feedback {
1616                    ThreadFeedback::Positive => "positive",
1617                    ThreadFeedback::Negative => "negative",
1618                };
1619                telemetry::event!(
1620                    "Assistant Thread Rated",
1621                    rating,
1622                    thread_id,
1623                    thread_data,
1624                    final_project_snapshot
1625                );
1626                client.telemetry().flush_events();
1627
1628                Ok(())
1629            })
1630        }
1631    }
1632
1633    /// Create a snapshot of the current project state including git information and unsaved buffers.
1634    fn project_snapshot(
1635        project: Entity<Project>,
1636        cx: &mut Context<Self>,
1637    ) -> Task<Arc<ProjectSnapshot>> {
1638        let git_store = project.read(cx).git_store().clone();
1639        let worktree_snapshots: Vec<_> = project
1640            .read(cx)
1641            .visible_worktrees(cx)
1642            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1643            .collect();
1644
1645        cx.spawn(async move |_, cx| {
1646            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1647
1648            let mut unsaved_buffers = Vec::new();
1649            cx.update(|app_cx| {
1650                let buffer_store = project.read(app_cx).buffer_store();
1651                for buffer_handle in buffer_store.read(app_cx).buffers() {
1652                    let buffer = buffer_handle.read(app_cx);
1653                    if buffer.is_dirty() {
1654                        if let Some(file) = buffer.file() {
1655                            let path = file.path().to_string_lossy().to_string();
1656                            unsaved_buffers.push(path);
1657                        }
1658                    }
1659                }
1660            })
1661            .ok();
1662
1663            Arc::new(ProjectSnapshot {
1664                worktree_snapshots,
1665                unsaved_buffer_paths: unsaved_buffers,
1666                timestamp: Utc::now(),
1667            })
1668        })
1669    }
1670
1671    fn worktree_snapshot(
1672        worktree: Entity<project::Worktree>,
1673        git_store: Entity<GitStore>,
1674        cx: &App,
1675    ) -> Task<WorktreeSnapshot> {
1676        cx.spawn(async move |cx| {
1677            // Get worktree path and snapshot
1678            let worktree_info = cx.update(|app_cx| {
1679                let worktree = worktree.read(app_cx);
1680                let path = worktree.abs_path().to_string_lossy().to_string();
1681                let snapshot = worktree.snapshot();
1682                (path, snapshot)
1683            });
1684
1685            let Ok((worktree_path, _snapshot)) = worktree_info else {
1686                return WorktreeSnapshot {
1687                    worktree_path: String::new(),
1688                    git_state: None,
1689                };
1690            };
1691
1692            let git_state = git_store
1693                .update(cx, |git_store, cx| {
1694                    git_store
1695                        .repositories()
1696                        .values()
1697                        .find(|repo| {
1698                            repo.read(cx)
1699                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1700                                .is_some()
1701                        })
1702                        .cloned()
1703                })
1704                .ok()
1705                .flatten()
1706                .map(|repo| {
1707                    repo.update(cx, |repo, _| {
1708                        let current_branch =
1709                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1710                        repo.send_job(None, |state, _| async move {
1711                            let RepositoryState::Local { backend, .. } = state else {
1712                                return GitState {
1713                                    remote_url: None,
1714                                    head_sha: None,
1715                                    current_branch,
1716                                    diff: None,
1717                                };
1718                            };
1719
1720                            let remote_url = backend.remote_url("origin");
1721                            let head_sha = backend.head_sha();
1722                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1723
1724                            GitState {
1725                                remote_url,
1726                                head_sha,
1727                                current_branch,
1728                                diff,
1729                            }
1730                        })
1731                    })
1732                });
1733
1734            let git_state = match git_state {
1735                Some(git_state) => match git_state.ok() {
1736                    Some(git_state) => git_state.await.ok(),
1737                    None => None,
1738                },
1739                None => None,
1740            };
1741
1742            WorktreeSnapshot {
1743                worktree_path,
1744                git_state,
1745            }
1746        })
1747    }
1748
1749    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1750        let mut markdown = Vec::new();
1751
1752        if let Some(summary) = self.summary() {
1753            writeln!(markdown, "# {summary}\n")?;
1754        };
1755
1756        for message in self.messages() {
1757            writeln!(
1758                markdown,
1759                "## {role}\n",
1760                role = match message.role {
1761                    Role::User => "User",
1762                    Role::Assistant => "Assistant",
1763                    Role::System => "System",
1764                }
1765            )?;
1766
1767            if !message.context.is_empty() {
1768                writeln!(markdown, "{}", message.context)?;
1769            }
1770
1771            for segment in &message.segments {
1772                match segment {
1773                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1774                    MessageSegment::Thinking(text) => {
1775                        writeln!(markdown, "<think>{}</think>\n", text)?
1776                    }
1777                }
1778            }
1779
1780            for tool_use in self.tool_uses_for_message(message.id, cx) {
1781                writeln!(
1782                    markdown,
1783                    "**Use Tool: {} ({})**",
1784                    tool_use.name, tool_use.id
1785                )?;
1786                writeln!(markdown, "```json")?;
1787                writeln!(
1788                    markdown,
1789                    "{}",
1790                    serde_json::to_string_pretty(&tool_use.input)?
1791                )?;
1792                writeln!(markdown, "```")?;
1793            }
1794
1795            for tool_result in self.tool_results_for_message(message.id) {
1796                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1797                if tool_result.is_error {
1798                    write!(markdown, " (Error)")?;
1799                }
1800
1801                writeln!(markdown, "**\n")?;
1802                writeln!(markdown, "{}", tool_result.content)?;
1803            }
1804        }
1805
1806        Ok(String::from_utf8_lossy(&markdown).to_string())
1807    }
1808
1809    pub fn keep_edits_in_range(
1810        &mut self,
1811        buffer: Entity<language::Buffer>,
1812        buffer_range: Range<language::Anchor>,
1813        cx: &mut Context<Self>,
1814    ) {
1815        self.action_log.update(cx, |action_log, cx| {
1816            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1817        });
1818    }
1819
1820    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1821        self.action_log
1822            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1823    }
1824
1825    pub fn reject_edits_in_ranges(
1826        &mut self,
1827        buffer: Entity<language::Buffer>,
1828        buffer_ranges: Vec<Range<language::Anchor>>,
1829        cx: &mut Context<Self>,
1830    ) -> Task<Result<()>> {
1831        self.action_log.update(cx, |action_log, cx| {
1832            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1833        })
1834    }
1835
1836    pub fn action_log(&self) -> &Entity<ActionLog> {
1837        &self.action_log
1838    }
1839
1840    pub fn project(&self) -> &Entity<Project> {
1841        &self.project
1842    }
1843
1844    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1845        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1846            return;
1847        }
1848
1849        let now = Instant::now();
1850        if let Some(last) = self.last_auto_capture_at {
1851            if now.duration_since(last).as_secs() < 10 {
1852                return;
1853            }
1854        }
1855
1856        self.last_auto_capture_at = Some(now);
1857
1858        let thread_id = self.id().clone();
1859        let github_login = self
1860            .project
1861            .read(cx)
1862            .user_store()
1863            .read(cx)
1864            .current_user()
1865            .map(|user| user.github_login.clone());
1866        let client = self.project.read(cx).client().clone();
1867        let serialize_task = self.serialize(cx);
1868
1869        cx.background_executor()
1870            .spawn(async move {
1871                if let Ok(serialized_thread) = serialize_task.await {
1872                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1873                        telemetry::event!(
1874                            "Agent Thread Auto-Captured",
1875                            thread_id = thread_id.to_string(),
1876                            thread_data = thread_data,
1877                            auto_capture_reason = "tracked_user",
1878                            github_login = github_login
1879                        );
1880
1881                        client.telemetry().flush_events();
1882                    }
1883                }
1884            })
1885            .detach();
1886    }
1887
1888    pub fn cumulative_token_usage(&self) -> TokenUsage {
1889        self.cumulative_token_usage
1890    }
1891
1892    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1893        let model_registry = LanguageModelRegistry::read_global(cx);
1894        let Some(model) = model_registry.default_model() else {
1895            return TotalTokenUsage::default();
1896        };
1897
1898        let max = model.model.max_token_count();
1899
1900        if let Some(exceeded_error) = &self.exceeded_window_error {
1901            if model.model.id() == exceeded_error.model_id {
1902                return TotalTokenUsage {
1903                    total: exceeded_error.token_count,
1904                    max,
1905                    ratio: TokenUsageRatio::Exceeded,
1906                };
1907            }
1908        }
1909
1910        #[cfg(debug_assertions)]
1911        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1912            .unwrap_or("0.8".to_string())
1913            .parse()
1914            .unwrap();
1915        #[cfg(not(debug_assertions))]
1916        let warning_threshold: f32 = 0.8;
1917
1918        let total = self.cumulative_token_usage.total_tokens() as usize;
1919
1920        let ratio = if total >= max {
1921            TokenUsageRatio::Exceeded
1922        } else if total as f32 / max as f32 >= warning_threshold {
1923            TokenUsageRatio::Warning
1924        } else {
1925            TokenUsageRatio::Normal
1926        };
1927
1928        TotalTokenUsage { total, max, ratio }
1929    }
1930
1931    pub fn deny_tool_use(
1932        &mut self,
1933        tool_use_id: LanguageModelToolUseId,
1934        tool_name: Arc<str>,
1935        cx: &mut Context<Self>,
1936    ) {
1937        let err = Err(anyhow::anyhow!(
1938            "Permission to run tool action denied by user"
1939        ));
1940
1941        self.tool_use
1942            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1943        self.tool_finished(tool_use_id.clone(), None, true, cx);
1944    }
1945}
1946
1947#[derive(Debug, Clone, Error)]
1948pub enum ThreadError {
1949    #[error("Payment required")]
1950    PaymentRequired,
1951    #[error("Max monthly spend reached")]
1952    MaxMonthlySpendReached,
1953    #[error("Model request limit reached")]
1954    ModelRequestLimitReached { plan: Plan },
1955    #[error("Message {header}: {message}")]
1956    Message {
1957        header: SharedString,
1958        message: SharedString,
1959    },
1960}
1961
1962#[derive(Debug, Clone)]
1963pub enum ThreadEvent {
1964    ShowError(ThreadError),
1965    StreamedCompletion,
1966    StreamedAssistantText(MessageId, String),
1967    StreamedAssistantThinking(MessageId, String),
1968    Stopped(Result<StopReason, Arc<anyhow::Error>>),
1969    MessageAdded(MessageId),
1970    MessageEdited(MessageId),
1971    MessageDeleted(MessageId),
1972    SummaryGenerated,
1973    SummaryChanged,
1974    UsePendingTools {
1975        tool_uses: Vec<PendingToolUse>,
1976    },
1977    ToolFinished {
1978        #[allow(unused)]
1979        tool_use_id: LanguageModelToolUseId,
1980        /// The pending tool use that corresponds to this tool.
1981        pending_tool_use: Option<PendingToolUse>,
1982    },
1983    CheckpointChanged,
1984    ToolConfirmationNeeded,
1985}
1986
1987impl EventEmitter<ThreadEvent> for Thread {}
1988
1989struct PendingCompletion {
1990    id: usize,
1991    _task: Task<()>,
1992}
1993
// Tests for `Thread`: how context attached to user messages is recorded, how it is
// de-duplicated across messages, and how completion requests are assembled (including
// the stale-buffer notice).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// A user message with one attached file stores the formatted context on the
    /// message, and the completion request sends that context followed by the text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // The rendered context shows the file path with the platform's separator.
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // `{{` / `}}` are escaped braces for `format!`; `{path_part}` is interpolated.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request: one leading message (presumably the system prompt)
        // followed by the user message, whose contents are the context immediately
        // followed by the message text.
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    /// Context already attached to an earlier message in the thread is not repeated:
    /// each message only carries the files that are new to the thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Add each file to the context store individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request holds one leading message (presumably the system prompt) plus the
        // 3 user messages, so the user messages are at indices 1..=3.
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    /// Messages inserted without any context keep an empty `context` string, and the
    /// request contains exactly their text.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request (index 0 is the leading, presumably system, message)
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    /// Editing a buffer after it was sent as context makes the next request end with
    /// a user message listing the changed files.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert at byte offset 1 (right after the leading `f`). The exact position
            // is irrelevant; the edit only has to make the buffer differ from what was
            // sent as context.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Installs a test `SettingsStore` and registers the settings types and
    /// initializers these tests rely on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    /// Creates a test project backed by a `FakeFs`, with `files` inserted under
    /// `/test` (platform-adjusted by `path!`) and that directory opened as its root.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Opens a test workspace window for `project`, loads a `ThreadStore` with a default
    /// tool working set, and returns it along with a newly created thread and a new
    /// `ContextStore` for the project.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await;

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    /// Opens `path` (project-relative, e.g. `test/code.rs`) as a buffer, adds it to
    /// `context_store`, and returns the buffer.
    ///
    /// Panics if the path cannot be resolved or opened. An error from adding the buffer
    /// to the context store is returned instead.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}