// thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  19    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  20    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  21    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
  22    Role, StopReason, TokenUsage,
  23};
  24use project::Project;
  25use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  26use prompt_store::PromptBuilder;
  27use schemars::JsonSchema;
  28use serde::{Deserialize, Serialize};
  29use settings::Settings;
  30use thiserror::Error;
  31use util::{ResultExt as _, TryFutureExt as _, post_inc};
  32use uuid::Uuid;
  33
  34use crate::context::{AssistantContext, ContextId, format_context_as_string};
  35use crate::thread_store::{
  36    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  37    SerializedToolUse, SharedProjectContext,
  38};
  39use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  40
/// The purpose of a completion request; passed to `Thread::to_completion_request`.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// Presumably a regular conversational turn (anything other than summarization).
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
  47
/// Identifier of a [`Thread`], wrapping the identifier string.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
  52
impl ThreadId {
    /// Creates a new identifier from a freshly generated random (v4) UUID string.
    pub fn new() -> Self {
        Self(Uuid::new_v4().to_string().into())
    }
}
  58
  59impl std::fmt::Display for ThreadId {
  60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  61        write!(f, "{}", self.0)
  62    }
  63}
  64
  65impl From<&str> for ThreadId {
  66    fn from(value: &str) -> Self {
  67        Self(value.into())
  68    }
  69}
  70
/// Identifier of a [`Message`] within a thread; new ids are handed out sequentially from the
/// thread's `next_message_id` counter.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
  73
impl MessageId {
    /// Delegates to `util::post_inc` on the inner counter; presumably returns the current id
    /// and advances `self` to the next one (post-increment) — confirm against `util`.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
  79
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier of this message within its thread.
    pub id: MessageId,
    /// Who authored the message.
    pub role: Role,
    /// The message content as ordered text and thinking segments.
    pub segments: Vec<MessageSegment>,
    /// Formatted context attached with this message (empty when none); prepended to the
    /// segments by `Message::to_string`.
    pub context: String,
}
  88
  89impl Message {
  90    /// Returns whether the message contains any meaningful text that should be displayed
  91    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  92    pub fn should_display_content(&self) -> bool {
  93        self.segments.iter().all(|segment| segment.should_display())
  94    }
  95
  96    pub fn push_thinking(&mut self, text: &str) {
  97        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  98            segment.push_str(text);
  99        } else {
 100            self.segments
 101                .push(MessageSegment::Thinking(text.to_string()));
 102        }
 103    }
 104
 105    pub fn push_text(&mut self, text: &str) {
 106        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 107            segment.push_str(text);
 108        } else {
 109            self.segments.push(MessageSegment::Text(text.to_string()));
 110        }
 111    }
 112
 113    pub fn to_string(&self) -> String {
 114        let mut result = String::new();
 115
 116        if !self.context.is_empty() {
 117            result.push_str(&self.context);
 118        }
 119
 120        for segment in &self.segments {
 121            match segment {
 122                MessageSegment::Text(text) => result.push_str(text),
 123                MessageSegment::Thinking(text) => {
 124                    result.push_str("<think>");
 125                    result.push_str(text);
 126                    result.push_str("</think>");
 127                }
 128            }
 129        }
 130
 131        result
 132    }
 133}
 134
/// One piece of a message's content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Regular text.
    Text(String),
    /// Model "thinking" text; wrapped in `<think>…</think>` by `Message::to_string` and
    /// `Thread::text`.
    Thinking(String),
}
 140
 141impl MessageSegment {
 142    pub fn text_mut(&mut self) -> &mut String {
 143        match self {
 144            Self::Text(text) => text,
 145            Self::Thinking(text) => text,
 146        }
 147    }
 148
 149    pub fn should_display(&self) -> bool {
 150        // We add USING_TOOL_MARKER when making a request that includes tool uses
 151        // without non-whitespace text around them, and this can cause the model
 152        // to mimic the pattern, so we consider those segments not displayable.
 153        match self {
 154            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 155            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 156        }
 157    }
 158}
 159
/// A snapshot of a project's state, stored as a thread's initial project snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Presumably paths of buffers with unsaved edits at snapshot time — confirm with producer.
    pub unsaved_buffer_paths: Vec<String>,
    /// When the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}
 166
/// Snapshot of a single worktree within a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree, stored as a string.
    pub worktree_path: String,
    /// Git metadata, when available.
    pub git_state: Option<GitState>,
}
 172
/// Git metadata captured for a worktree; every field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// Presumably a textual diff of uncommitted changes — confirm with producer.
    pub diff: Option<String>,
}
 180
/// A git checkpoint paired with the message it was taken for. Restoring it truncates the
/// thread at that message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}
 186
/// Positive or negative feedback, recorded for a whole thread or for individual messages.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 192
/// Status of the most recent checkpoint restore.
pub enum LastRestoreCheckpoint {
    /// A restore of the checkpoint for `message_id` is in progress.
    Pending {
        message_id: MessageId,
    },
    /// Restoring the checkpoint for `message_id` failed; `error` holds the error text.
    Error {
        message_id: MessageId,
        error: String,
    },
}
 202
 203impl LastRestoreCheckpoint {
 204    pub fn message_id(&self) -> MessageId {
 205        match self {
 206            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 207            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 208        }
 209    }
 210}
 211
/// Progress of generating a detailed summary of the thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary exists yet.
    #[default]
    NotGenerated,
    /// A summary is being generated.
    /// NOTE(review): `message_id` is presumably the latest message covered — confirm.
    Generating {
        message_id: MessageId,
    },
    /// A summary is available in `text`; `message_id` presumably marks the last message it
    /// covers — confirm.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 224
/// Aggregate token usage together with a coarse classification in `ratio`.
#[derive(Default)]
pub struct TotalTokenUsage {
    /// Tokens used. NOTE(review): presumably the thread's total token count — confirm.
    pub total: usize,
    /// Token limit. NOTE(review): presumably the model's context window size — confirm.
    pub max: usize,
    /// Coarse bucket for how `total` relates to `max`.
    pub ratio: TokenUsageRatio,
}
 231
/// Coarse classification of token usage relative to the limit; the thresholds are defined
/// elsewhere.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
 239
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Identifier of this thread.
    id: ThreadId,
    /// Last modification time; bumped by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// Summary/title; `None` until one has been generated.
    summary: Option<SharedString>,
    /// Summary generation task; initialized to an already-completed task.
    pending_summary: Task<Option<()>>,
    /// Progress of the detailed summary.
    detailed_summary_state: DetailedSummaryState,
    /// Messages in conversation order.
    messages: Vec<Message>,
    /// Id that will be assigned to the next inserted message.
    next_message_id: MessageId,
    /// All context attached to this thread, keyed by id; used to avoid attaching the same
    /// context twice.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Context ids newly attached with each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Project context used to build the system prompt.
    project_context: SharedProjectContext,
    /// Restorable git checkpoints, keyed by the message they were taken for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// In-flight completions; non-empty means the thread is generating.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    /// Tools that may be offered to the model.
    tools: Arc<ToolWorkingSet>,
    /// Tool uses and their results per message.
    tool_use: ToolUseState,
    /// Records buffers relevant to this thread (e.g. buffers added as context).
    action_log: Entity<ActionLog>,
    /// Status of the latest checkpoint restore; `None` after a successful restore.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken with the latest user message, not yet kept or discarded by
    /// `finalize_pending_checkpoint`.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot from when the thread was created (or as deserialized).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    /// Feedback on the thread as a whole.
    feedback: Option<ThreadFeedback>,
    /// Feedback on individual messages.
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
}
 269
/// Records that a request exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 277
 278impl Thread {
    /// Creates an empty thread for `project`.
    ///
    /// The thread starts with a fresh [`ThreadId`], no messages, no summary and a new
    /// [`ActionLog`]. `system_prompt` is stored as the thread's project context. A snapshot of
    /// the project is captured asynchronously and exposed as the initial project snapshot.
    pub fn new(
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        system_prompt: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context: system_prompt,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new(project.clone())),
            // Struct fields are evaluated in the order written, so this is the last use of
            // `project` and it can be moved here.
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            cumulative_token_usage: TokenUsage::default(),
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
        }
    }
 320
    /// Rebuilds a thread with id `id` from its serialized form.
    ///
    /// Messages (segments and context strings), summary, detailed-summary state, the initial
    /// project snapshot and cumulative token usage come from `serialized`. Tool-use state is
    /// rebuilt via `ToolUseState::from_serialized_messages`. Per-session state (attached
    /// context maps, checkpoints, pending completions, feedback) starts empty, and a new
    /// [`ActionLog`] is created.
    pub fn deserialize(
        id: ThreadId,
        serialized: SerializedThread,
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        project_context: SharedProjectContext,
        cx: &mut Context<Self>,
    ) -> Self {
        // Continue numbering after the last stored message.
        // NOTE(review): assumes the last serialized message carries the highest id.
        let next_message_id = MessageId(
            serialized
                .messages
                .last()
                .map(|message| message.id.0 + 1)
                .unwrap_or(0),
        );
        let tool_use =
            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);

        Self {
            id,
            updated_at: serialized.updated_at,
            summary: Some(serialized.summary),
            pending_summary: Task::ready(None),
            detailed_summary_state: serialized.detailed_summary_state,
            messages: serialized
                .messages
                .into_iter()
                .map(|message| Message {
                    id: message.id,
                    role: message.role,
                    segments: message
                        .segments
                        .into_iter()
                        .map(|segment| match segment {
                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
                            SerializedMessageSegment::Thinking { text } => {
                                MessageSegment::Thinking(text)
                            }
                        })
                        .collect(),
                    context: message.context,
                })
                .collect(),
            next_message_id,
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            project_context,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            project: project.clone(),
            prompt_builder,
            tools,
            tool_use,
            action_log: cx.new(|_| ActionLog::new(project)),
            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
            cumulative_token_usage: serialized.cumulative_token_usage,
            // NOTE(review): `serialize` writes `exceeded_window_error`, but it is not restored
            // here; confirm whether that is intentional.
            exceeded_window_error: None,
            feedback: None,
            message_feedback: HashMap::default(),
            last_auto_capture_at: None,
        }
    }
 387
    /// Returns this thread's identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 391
    /// Returns `true` if the thread has no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 395
    /// Returns when the thread was last modified.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 399
    /// Sets the modification time to the current UTC time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 403
    /// Returns the thread's summary, or `None` if none has been generated yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
 407
    /// Returns a clone of the project context used to build the system prompt.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 411
    /// Summary used when none exists yet, or when an empty summary is set.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 413
 414    pub fn summary_or_default(&self) -> SharedString {
 415        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 416    }
 417
 418    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 419        let Some(current_summary) = &self.summary else {
 420            // Don't allow setting summary until generated
 421            return;
 422        };
 423
 424        let mut new_summary = new_summary.into();
 425
 426        if new_summary.is_empty() {
 427            new_summary = Self::DEFAULT_SUMMARY;
 428        }
 429
 430        if current_summary != &new_summary {
 431            self.summary = Some(new_summary);
 432            cx.emit(ThreadEvent::SummaryChanged);
 433        }
 434    }
 435
 436    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 437        self.latest_detailed_summary()
 438            .unwrap_or_else(|| self.text().into())
 439    }
 440
 441    fn latest_detailed_summary(&self) -> Option<SharedString> {
 442        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 443            Some(text.clone())
 444        } else {
 445            None
 446        }
 447    }
 448
    /// Looks up a message by id (linear scan).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }
 452
    /// Iterates over the thread's messages in order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }
 456
 457    pub fn is_generating(&self) -> bool {
 458        !self.pending_completions.is_empty() || !self.all_tools_finished()
 459    }
 460
    /// Returns the tool working set available to this thread.
    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
        &self.tools
    }
 464
    /// Returns the pending tool use with the given id, if any.
    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .find(|tool_use| &tool_use.id == id)
    }
 471
    /// Iterates over pending tool uses whose status reports that they need confirmation.
    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
        self.tool_use
            .pending_tool_uses()
            .into_iter()
            .filter(|tool_use| tool_use.status.needs_confirmation())
    }
 478
    /// Returns `true` if any tool use is pending. This also counts pending tool uses that
    /// ended in an error (which [`Self::all_tools_finished`] treats as finished).
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
 482
    /// Returns a clone of the checkpoint recorded for message `id`, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 486
    /// Restores the project's git state to `checkpoint` and rewinds the thread to it.
    ///
    /// The restore is marked as pending immediately. When the git restore finishes:
    /// - On success, the thread is truncated at the checkpoint's message and
    ///   `last_restore_checkpoint` is cleared.
    /// - On failure, the error text is stored in `last_restore_checkpoint`.
    ///
    /// In both cases the pending checkpoint is discarded and `CheckpointChanged` is emitted.
    /// The returned task yields the git restore's result.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            // Propagate the restore outcome to the caller after updating thread state.
            result
        })
    }
 522
    /// Decides whether the pending checkpoint (set by `insert_user_message`) is kept.
    ///
    /// Does nothing while the thread is still generating or when there is no pending
    /// checkpoint. Otherwise it takes a fresh "final" git checkpoint and compares it with the
    /// pending one:
    /// - If they compare equal, the pending checkpoint is deleted.
    /// - Otherwise, the pending checkpoint is recorded via `insert_checkpoint`.
    ///
    /// The temporary final checkpoint is deleted afterwards. If taking the final checkpoint
    /// fails, the pending checkpoint is recorded.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison counts as "not equal", so the checkpoint is kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The final checkpoint was only needed for the comparison.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 573
    /// Records `checkpoint` for its message (replacing any existing entry), emits
    /// `CheckpointChanged` and notifies observers.
    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
        self.checkpoints_by_message
            .insert(checkpoint.message_id, checkpoint);
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();
    }
 580
    /// Returns the status of the last checkpoint restore while it is pending or has failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 584
 585    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 586        let Some(message_ix) = self
 587            .messages
 588            .iter()
 589            .rposition(|message| message.id == message_id)
 590        else {
 591            return;
 592        };
 593        for deleted_message in self.messages.drain(message_ix..) {
 594            self.context_by_message.remove(&deleted_message.id);
 595            self.checkpoints_by_message.remove(&deleted_message.id);
 596        }
 597        cx.notify();
 598    }
 599
 600    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 601        self.context_by_message
 602            .get(&id)
 603            .into_iter()
 604            .flat_map(|context| {
 605                context
 606                    .iter()
 607                    .filter_map(|context_id| self.context.get(&context_id))
 608            })
 609    }
 610
 611    /// Returns whether all of the tool uses have finished running.
 612    pub fn all_tools_finished(&self) -> bool {
 613        // If the only pending tool uses left are the ones with errors, then
 614        // that means that we've finished running all of the pending tools.
 615        self.tool_use
 616            .pending_tool_uses()
 617            .iter()
 618            .all(|tool_use| tool_use.status.is_error())
 619    }
 620
    /// Returns the tool uses recorded for message `id`.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 624
    /// Returns the tool results recorded for message `id`.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }
 628
    /// Returns the result of the tool use with the given id, if one has been recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 632
    /// Returns `true` if any tool results are recorded for `message_id`.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
 636
    /// Appends a user message containing `text` and attaches the items of `context` that are
    /// not yet attached to this thread.
    ///
    /// Newly attached context is formatted into the message's `context` string. Buffers from
    /// file, directory and symbol context are reported to the action log. If `git_checkpoint`
    /// is given, it becomes the pending checkpoint for this message, to be kept or discarded
    /// later by `finalize_pending_checkpoint`. Returns the new message's id.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Filter out contexts that have already been included in previous messages
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| !self.context.contains_key(&ctx.id()))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
                    }
                }
            });
        }

        // Remember which context ids arrived with this message, then register the context
        // itself so later messages can skip it.
        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
 707
    /// Appends a message with a freshly allocated id and an empty context string.
    ///
    /// Bumps `updated_at`, emits `MessageAdded`, and returns the new id.
    pub fn insert_message(
        &mut self,
        role: Role,
        segments: Vec<MessageSegment>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let id = self.next_message_id.post_inc();
        self.messages.push(Message {
            id,
            role,
            segments,
            context: String::new(),
        });
        self.touch_updated_at();
        cx.emit(ThreadEvent::MessageAdded(id));
        id
    }
 725
 726    pub fn edit_message(
 727        &mut self,
 728        id: MessageId,
 729        new_role: Role,
 730        new_segments: Vec<MessageSegment>,
 731        cx: &mut Context<Self>,
 732    ) -> bool {
 733        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 734            return false;
 735        };
 736        message.role = new_role;
 737        message.segments = new_segments;
 738        self.touch_updated_at();
 739        cx.emit(ThreadEvent::MessageEdited(id));
 740        true
 741    }
 742
 743    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 744        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 745            return false;
 746        };
 747        self.messages.remove(index);
 748        self.context_by_message.remove(&id);
 749        self.touch_updated_at();
 750        cx.emit(ThreadEvent::MessageDeleted(id));
 751        true
 752    }
 753
 754    /// Returns the representation of this [`Thread`] in a textual form.
 755    ///
 756    /// This is the representation we use when attaching a thread as context to another thread.
 757    pub fn text(&self) -> String {
 758        let mut text = String::new();
 759
 760        for message in &self.messages {
 761            text.push_str(match message.role {
 762                language_model::Role::User => "User:",
 763                language_model::Role::Assistant => "Assistant:",
 764                language_model::Role::System => "System:",
 765            });
 766            text.push('\n');
 767
 768            for segment in &message.segments {
 769                match segment {
 770                    MessageSegment::Text(content) => text.push_str(content),
 771                    MessageSegment::Thinking(content) => {
 772                        text.push_str(&format!("<think>{}</think>", content))
 773                    }
 774                }
 775            }
 776            text.push('\n');
 777        }
 778
 779        text
 780    }
 781
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// The returned task first waits for the initial project snapshot, then reads the thread and
    /// captures:
    /// - messages, each with its segments, tool uses, tool results and context string;
    /// - the summary (or [`Self::DEFAULT_SUMMARY`]) and `updated_at`;
    /// - cumulative token usage, the detailed-summary state and any exceeded-window error.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking(text) => {
                                    SerializedMessageSegment::Thinking { text: text.clone() }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
 836
 837    pub fn send_to_model(
 838        &mut self,
 839        model: Arc<dyn LanguageModel>,
 840        request_kind: RequestKind,
 841        cx: &mut Context<Self>,
 842    ) {
 843        let mut request = self.to_completion_request(request_kind, cx);
 844        if model.supports_tools() {
 845            request.tools = {
 846                let mut tools = Vec::new();
 847                tools.extend(
 848                    self.tools()
 849                        .enabled_tools(cx)
 850                        .into_iter()
 851                        .filter_map(|tool| {
 852                            // Skip tools that cannot be supported
 853                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 854                            Some(LanguageModelRequestTool {
 855                                name: tool.name(),
 856                                description: tool.description(),
 857                                input_schema,
 858                            })
 859                        }),
 860                );
 861
 862                tools
 863            };
 864        }
 865
 866        self.stream_completion(request, model, cx);
 867    }
 868
 869    pub fn used_tools_since_last_user_message(&self) -> bool {
 870        for message in self.messages.iter().rev() {
 871            if self.tool_use.message_has_tool_results(message.id) {
 872                return true;
 873            } else if message.role == Role::User {
 874                return false;
 875            }
 876        }
 877
 878        false
 879    }
 880
 881    pub fn to_completion_request(
 882        &self,
 883        request_kind: RequestKind,
 884        cx: &App,
 885    ) -> LanguageModelRequest {
 886        let mut request = LanguageModelRequest {
 887            messages: vec![],
 888            tools: Vec::new(),
 889            stop: Vec::new(),
 890            temperature: None,
 891        };
 892
 893        if let Some(project_context) = self.project_context.borrow().as_ref() {
 894            if let Some(system_prompt) = self
 895                .prompt_builder
 896                .generate_assistant_system_prompt(project_context)
 897                .context("failed to generate assistant system prompt")
 898                .log_err()
 899            {
 900                request.messages.push(LanguageModelRequestMessage {
 901                    role: Role::System,
 902                    content: vec![MessageContent::Text(system_prompt)],
 903                    cache: true,
 904                });
 905            }
 906        } else {
 907            log::error!("project_context not set.")
 908        }
 909
 910        for message in &self.messages {
 911            let mut request_message = LanguageModelRequestMessage {
 912                role: message.role,
 913                content: Vec::new(),
 914                cache: false,
 915            };
 916
 917            match request_kind {
 918                RequestKind::Chat => {
 919                    self.tool_use
 920                        .attach_tool_results(message.id, &mut request_message);
 921                }
 922                RequestKind::Summarize => {
 923                    // We don't care about tool use during summarization.
 924                    if self.tool_use.message_has_tool_results(message.id) {
 925                        continue;
 926                    }
 927                }
 928            }
 929
 930            if !message.segments.is_empty() {
 931                request_message
 932                    .content
 933                    .push(MessageContent::Text(message.to_string()));
 934            }
 935
 936            match request_kind {
 937                RequestKind::Chat => {
 938                    self.tool_use
 939                        .attach_tool_uses(message.id, &mut request_message);
 940                }
 941                RequestKind::Summarize => {
 942                    // We don't care about tool use during summarization.
 943                }
 944            };
 945
 946            request.messages.push(request_message);
 947        }
 948
 949        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 950        if let Some(last) = request.messages.last_mut() {
 951            last.cache = true;
 952        }
 953
 954        self.attached_tracked_files_state(&mut request.messages, cx);
 955
 956        request
 957    }
 958
 959    fn attached_tracked_files_state(
 960        &self,
 961        messages: &mut Vec<LanguageModelRequestMessage>,
 962        cx: &App,
 963    ) {
 964        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 965
 966        let mut stale_message = String::new();
 967
 968        let action_log = self.action_log.read(cx);
 969
 970        for stale_file in action_log.stale_buffers(cx) {
 971            let Some(file) = stale_file.read(cx).file() else {
 972                continue;
 973            };
 974
 975            if stale_message.is_empty() {
 976                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
 977            }
 978
 979            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 980        }
 981
 982        let mut content = Vec::with_capacity(2);
 983
 984        if !stale_message.is_empty() {
 985            content.push(stale_message.into());
 986        }
 987
 988        if action_log.has_edited_files_since_project_diagnostics_check() {
 989            content.push(
 990                "\n\nWhen you're done making changes, make sure to check project diagnostics \
 991                and fix all errors AND warnings you introduced! \
 992                DO NOT mention you're going to do this until you're done."
 993                    .into(),
 994            );
 995        }
 996
 997        if !content.is_empty() {
 998            let context_message = LanguageModelRequestMessage {
 999                role: Role::User,
1000                content,
1001                cache: false,
1002            };
1003
1004            messages.push(context_message);
1005        }
1006    }
1007
    /// Streams a completion for `request` from `model` into this thread.
    ///
    /// The spawned task is tracked in `pending_completions` under a fresh id
    /// taken from `completion_count`. As events arrive they are applied to the
    /// thread:
    /// - `StartMessage` inserts a new, empty assistant message;
    /// - `Text` / `Thinking` chunks are appended to the last message if it is
    ///   from the assistant, otherwise they start a new assistant message;
    /// - `ToolUse` is recorded against the most recent assistant message
    ///   (one is inserted if none exists);
    /// - `UsageUpdate` adjusts `cumulative_token_usage`;
    /// - `Stop` records the stop reason.
    ///
    /// When the stream ends successfully, the completion is removed from
    /// `pending_completions`, and a title summary is requested if the thread has
    /// none yet and at least two messages. Then, regardless of outcome, the
    /// pending checkpoint is finalized and:
    /// - a `ToolUse` stop reason runs the pending tools and emits
    ///   `ThreadEvent::UsePendingTools`;
    /// - an error is surfaced via `ThreadEvent::ShowError` (or, for a context
    ///   window overflow, recorded in `exceeded_window_error`) and the last
    ///   completion is canceled;
    /// - `ThreadEvent::Stopped` is emitted with the result, and, if the initial
    ///   token usage could be read, an "Assistant Thread Completion" telemetry
    ///   event reports the tokens used by this completion.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            // Snapshot of usage before this completion, used for the telemetry delta below.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                // Usage most recently applied for this completion.
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Each update replaces the previous one for this
                                // completion: remove what was applied before and add
                                // the new value.
                                thread.cumulative_token_usage = thread.cumulative_token_usage
                                    + token_usage
                                    - current_token_usage;
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the newest assistant message,
                                // creating an empty one if the thread has none.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Yield between events so other tasks get a chance to run.
                    smol::future::yield_now().await;
                }

                // Successful end of stream: this completion is no longer pending.
                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            // NOTE(review): on error this completion was not removed from
                            // `pending_completions` by id; `cancel_last_completion` pops the
                            // *last* pending completion, which is assumed to be this one —
                            // confirm this holds when several completions can be in flight.
                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        // Tokens consumed by this completion alone.
                        let usage = thread.cumulative_token_usage - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1208
1209    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1210        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1211            return;
1212        };
1213
1214        if !model.provider.is_authenticated(cx) {
1215            return;
1216        }
1217
1218        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1219        request.messages.push(LanguageModelRequestMessage {
1220            role: Role::User,
1221            content: vec![
1222                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1223                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1224                 If the conversation is about a specific subject, include it in the title. \
1225                 Be descriptive. DO NOT speak in the first person."
1226                    .into(),
1227            ],
1228            cache: false,
1229        });
1230
1231        self.pending_summary = cx.spawn(async move |this, cx| {
1232            async move {
1233                let stream = model.model.stream_completion_text(request, &cx);
1234                let mut messages = stream.await?;
1235
1236                let mut new_summary = String::new();
1237                while let Some(message) = messages.stream.next().await {
1238                    let text = message?;
1239                    let mut lines = text.lines();
1240                    new_summary.extend(lines.next());
1241
1242                    // Stop if the LLM generated multiple lines.
1243                    if lines.next().is_some() {
1244                        break;
1245                    }
1246                }
1247
1248                this.update(cx, |this, cx| {
1249                    if !new_summary.is_empty() {
1250                        this.summary = Some(new_summary.into());
1251                    }
1252
1253                    cx.emit(ThreadEvent::SummaryGenerated);
1254                })?;
1255
1256                anyhow::Ok(())
1257            }
1258            .log_err()
1259            .await
1260        });
1261    }
1262
1263    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1264        let last_message_id = self.messages.last().map(|message| message.id)?;
1265
1266        match &self.detailed_summary_state {
1267            DetailedSummaryState::Generating { message_id, .. }
1268            | DetailedSummaryState::Generated { message_id, .. }
1269                if *message_id == last_message_id =>
1270            {
1271                // Already up-to-date
1272                return None;
1273            }
1274            _ => {}
1275        }
1276
1277        let ConfiguredModel { model, provider } =
1278            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1279
1280        if !provider.is_authenticated(cx) {
1281            return None;
1282        }
1283
1284        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1285
1286        request.messages.push(LanguageModelRequestMessage {
1287            role: Role::User,
1288            content: vec![
1289                "Generate a detailed summary of this conversation. Include:\n\
1290                1. A brief overview of what was discussed\n\
1291                2. Key facts or information discovered\n\
1292                3. Outcomes or conclusions reached\n\
1293                4. Any action items or next steps if any\n\
1294                Format it in Markdown with headings and bullet points."
1295                    .into(),
1296            ],
1297            cache: false,
1298        });
1299
1300        let task = cx.spawn(async move |thread, cx| {
1301            let stream = model.stream_completion_text(request, &cx);
1302            let Some(mut messages) = stream.await.log_err() else {
1303                thread
1304                    .update(cx, |this, _cx| {
1305                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1306                    })
1307                    .log_err();
1308
1309                return;
1310            };
1311
1312            let mut new_detailed_summary = String::new();
1313
1314            while let Some(chunk) = messages.stream.next().await {
1315                if let Some(chunk) = chunk.log_err() {
1316                    new_detailed_summary.push_str(&chunk);
1317                }
1318            }
1319
1320            thread
1321                .update(cx, |this, _cx| {
1322                    this.detailed_summary_state = DetailedSummaryState::Generated {
1323                        text: new_detailed_summary.into(),
1324                        message_id: last_message_id,
1325                    };
1326                })
1327                .log_err();
1328        });
1329
1330        self.detailed_summary_state = DetailedSummaryState::Generating {
1331            message_id: last_message_id,
1332        };
1333
1334        Some(task)
1335    }
1336
1337    pub fn is_generating_detailed_summary(&self) -> bool {
1338        matches!(
1339            self.detailed_summary_state,
1340            DetailedSummaryState::Generating { .. }
1341        )
1342    }
1343
1344    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1345        self.auto_capture_telemetry(cx);
1346        let request = self.to_completion_request(RequestKind::Chat, cx);
1347        let messages = Arc::new(request.messages);
1348        let pending_tool_uses = self
1349            .tool_use
1350            .pending_tool_uses()
1351            .into_iter()
1352            .filter(|tool_use| tool_use.status.is_idle())
1353            .cloned()
1354            .collect::<Vec<_>>();
1355
1356        for tool_use in pending_tool_uses.iter() {
1357            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1358                if tool.needs_confirmation(&tool_use.input, cx)
1359                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1360                {
1361                    self.tool_use.confirm_tool_use(
1362                        tool_use.id.clone(),
1363                        tool_use.ui_text.clone(),
1364                        tool_use.input.clone(),
1365                        messages.clone(),
1366                        tool,
1367                    );
1368                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1369                } else {
1370                    self.run_tool(
1371                        tool_use.id.clone(),
1372                        tool_use.ui_text.clone(),
1373                        tool_use.input.clone(),
1374                        &messages,
1375                        tool,
1376                        cx,
1377                    );
1378                }
1379            }
1380        }
1381
1382        pending_tool_uses
1383    }
1384
1385    pub fn run_tool(
1386        &mut self,
1387        tool_use_id: LanguageModelToolUseId,
1388        ui_text: impl Into<SharedString>,
1389        input: serde_json::Value,
1390        messages: &[LanguageModelRequestMessage],
1391        tool: Arc<dyn Tool>,
1392        cx: &mut Context<Thread>,
1393    ) {
1394        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1395        self.tool_use
1396            .run_pending_tool(tool_use_id, ui_text.into(), task);
1397    }
1398
1399    fn spawn_tool_use(
1400        &mut self,
1401        tool_use_id: LanguageModelToolUseId,
1402        messages: &[LanguageModelRequestMessage],
1403        input: serde_json::Value,
1404        tool: Arc<dyn Tool>,
1405        cx: &mut Context<Thread>,
1406    ) -> Task<()> {
1407        let tool_name: Arc<str> = tool.name().into();
1408
1409        let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1410            Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1411        } else {
1412            tool.run(
1413                input,
1414                messages,
1415                self.project.clone(),
1416                self.action_log.clone(),
1417                cx,
1418            )
1419        };
1420
1421        cx.spawn({
1422            async move |thread: WeakEntity<Thread>, cx| {
1423                let output = run_tool.await;
1424
1425                thread
1426                    .update(cx, |thread, cx| {
1427                        let pending_tool_use = thread.tool_use.insert_tool_output(
1428                            tool_use_id.clone(),
1429                            tool_name,
1430                            output,
1431                            cx,
1432                        );
1433                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1434                    })
1435                    .ok();
1436            }
1437        })
1438    }
1439
1440    fn tool_finished(
1441        &mut self,
1442        tool_use_id: LanguageModelToolUseId,
1443        pending_tool_use: Option<PendingToolUse>,
1444        canceled: bool,
1445        cx: &mut Context<Self>,
1446    ) {
1447        if self.all_tools_finished() {
1448            let model_registry = LanguageModelRegistry::read_global(cx);
1449            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1450                self.attach_tool_results(cx);
1451                if !canceled {
1452                    self.send_to_model(model, RequestKind::Chat, cx);
1453                }
1454            }
1455        }
1456
1457        cx.emit(ThreadEvent::ToolFinished {
1458            tool_use_id,
1459            pending_tool_use,
1460        });
1461    }
1462
1463    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1464        // Insert a user message to contain the tool results.
1465        self.insert_user_message(
1466            // TODO: Sending up a user message without any content results in the model sending back
1467            // responses that also don't have any content. We currently don't handle this case well,
1468            // so for now we provide some text to keep the model on track.
1469            "Here are the tool results.",
1470            Vec::new(),
1471            None,
1472            cx,
1473        );
1474    }
1475
1476    /// Cancels the last pending completion, if there are any pending.
1477    ///
1478    /// Returns whether a completion was canceled.
1479    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1480        let canceled = if self.pending_completions.pop().is_some() {
1481            true
1482        } else {
1483            let mut canceled = false;
1484            for pending_tool_use in self.tool_use.cancel_pending() {
1485                canceled = true;
1486                self.tool_finished(
1487                    pending_tool_use.id.clone(),
1488                    Some(pending_tool_use),
1489                    true,
1490                    cx,
1491                );
1492            }
1493            canceled
1494        };
1495        self.finalize_pending_checkpoint(cx);
1496        canceled
1497    }
1498
1499    pub fn feedback(&self) -> Option<ThreadFeedback> {
1500        self.feedback
1501    }
1502
1503    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1504        self.message_feedback.get(&message_id).copied()
1505    }
1506
1507    pub fn report_message_feedback(
1508        &mut self,
1509        message_id: MessageId,
1510        feedback: ThreadFeedback,
1511        cx: &mut Context<Self>,
1512    ) -> Task<Result<()>> {
1513        if self.message_feedback.get(&message_id) == Some(&feedback) {
1514            return Task::ready(Ok(()));
1515        }
1516
1517        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1518        let serialized_thread = self.serialize(cx);
1519        let thread_id = self.id().clone();
1520        let client = self.project.read(cx).client();
1521
1522        let enabled_tool_names: Vec<String> = self
1523            .tools()
1524            .enabled_tools(cx)
1525            .iter()
1526            .map(|tool| tool.name().to_string())
1527            .collect();
1528
1529        self.message_feedback.insert(message_id, feedback);
1530
1531        cx.notify();
1532
1533        let message_content = self
1534            .message(message_id)
1535            .map(|msg| msg.to_string())
1536            .unwrap_or_default();
1537
1538        cx.background_spawn(async move {
1539            let final_project_snapshot = final_project_snapshot.await;
1540            let serialized_thread = serialized_thread.await?;
1541            let thread_data =
1542                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1543
1544            let rating = match feedback {
1545                ThreadFeedback::Positive => "positive",
1546                ThreadFeedback::Negative => "negative",
1547            };
1548            telemetry::event!(
1549                "Assistant Thread Rated",
1550                rating,
1551                thread_id,
1552                enabled_tool_names,
1553                message_id = message_id.0,
1554                message_content,
1555                thread_data,
1556                final_project_snapshot
1557            );
1558            client.telemetry().flush_events();
1559
1560            Ok(())
1561        })
1562    }
1563
1564    pub fn report_feedback(
1565        &mut self,
1566        feedback: ThreadFeedback,
1567        cx: &mut Context<Self>,
1568    ) -> Task<Result<()>> {
1569        let last_assistant_message_id = self
1570            .messages
1571            .iter()
1572            .rev()
1573            .find(|msg| msg.role == Role::Assistant)
1574            .map(|msg| msg.id);
1575
1576        if let Some(message_id) = last_assistant_message_id {
1577            self.report_message_feedback(message_id, feedback, cx)
1578        } else {
1579            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1580            let serialized_thread = self.serialize(cx);
1581            let thread_id = self.id().clone();
1582            let client = self.project.read(cx).client();
1583            self.feedback = Some(feedback);
1584            cx.notify();
1585
1586            cx.background_spawn(async move {
1587                let final_project_snapshot = final_project_snapshot.await;
1588                let serialized_thread = serialized_thread.await?;
1589                let thread_data = serde_json::to_value(serialized_thread)
1590                    .unwrap_or_else(|_| serde_json::Value::Null);
1591
1592                let rating = match feedback {
1593                    ThreadFeedback::Positive => "positive",
1594                    ThreadFeedback::Negative => "negative",
1595                };
1596                telemetry::event!(
1597                    "Assistant Thread Rated",
1598                    rating,
1599                    thread_id,
1600                    thread_data,
1601                    final_project_snapshot
1602                );
1603                client.telemetry().flush_events();
1604
1605                Ok(())
1606            })
1607        }
1608    }
1609
1610    /// Create a snapshot of the current project state including git information and unsaved buffers.
1611    fn project_snapshot(
1612        project: Entity<Project>,
1613        cx: &mut Context<Self>,
1614    ) -> Task<Arc<ProjectSnapshot>> {
1615        let git_store = project.read(cx).git_store().clone();
1616        let worktree_snapshots: Vec<_> = project
1617            .read(cx)
1618            .visible_worktrees(cx)
1619            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1620            .collect();
1621
1622        cx.spawn(async move |_, cx| {
1623            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1624
1625            let mut unsaved_buffers = Vec::new();
1626            cx.update(|app_cx| {
1627                let buffer_store = project.read(app_cx).buffer_store();
1628                for buffer_handle in buffer_store.read(app_cx).buffers() {
1629                    let buffer = buffer_handle.read(app_cx);
1630                    if buffer.is_dirty() {
1631                        if let Some(file) = buffer.file() {
1632                            let path = file.path().to_string_lossy().to_string();
1633                            unsaved_buffers.push(path);
1634                        }
1635                    }
1636                }
1637            })
1638            .ok();
1639
1640            Arc::new(ProjectSnapshot {
1641                worktree_snapshots,
1642                unsaved_buffer_paths: unsaved_buffers,
1643                timestamp: Utc::now(),
1644            })
1645        })
1646    }
1647
1648    fn worktree_snapshot(
1649        worktree: Entity<project::Worktree>,
1650        git_store: Entity<GitStore>,
1651        cx: &App,
1652    ) -> Task<WorktreeSnapshot> {
1653        cx.spawn(async move |cx| {
1654            // Get worktree path and snapshot
1655            let worktree_info = cx.update(|app_cx| {
1656                let worktree = worktree.read(app_cx);
1657                let path = worktree.abs_path().to_string_lossy().to_string();
1658                let snapshot = worktree.snapshot();
1659                (path, snapshot)
1660            });
1661
1662            let Ok((worktree_path, _snapshot)) = worktree_info else {
1663                return WorktreeSnapshot {
1664                    worktree_path: String::new(),
1665                    git_state: None,
1666                };
1667            };
1668
1669            let git_state = git_store
1670                .update(cx, |git_store, cx| {
1671                    git_store
1672                        .repositories()
1673                        .values()
1674                        .find(|repo| {
1675                            repo.read(cx)
1676                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1677                                .is_some()
1678                        })
1679                        .cloned()
1680                })
1681                .ok()
1682                .flatten()
1683                .map(|repo| {
1684                    repo.update(cx, |repo, _| {
1685                        let current_branch =
1686                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1687                        repo.send_job(None, |state, _| async move {
1688                            let RepositoryState::Local { backend, .. } = state else {
1689                                return GitState {
1690                                    remote_url: None,
1691                                    head_sha: None,
1692                                    current_branch,
1693                                    diff: None,
1694                                };
1695                            };
1696
1697                            let remote_url = backend.remote_url("origin");
1698                            let head_sha = backend.head_sha();
1699                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1700
1701                            GitState {
1702                                remote_url,
1703                                head_sha,
1704                                current_branch,
1705                                diff,
1706                            }
1707                        })
1708                    })
1709                });
1710
1711            let git_state = match git_state {
1712                Some(git_state) => match git_state.ok() {
1713                    Some(git_state) => git_state.await.ok(),
1714                    None => None,
1715                },
1716                None => None,
1717            };
1718
1719            WorktreeSnapshot {
1720                worktree_path,
1721                git_state,
1722            }
1723        })
1724    }
1725
1726    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1727        let mut markdown = Vec::new();
1728
1729        if let Some(summary) = self.summary() {
1730            writeln!(markdown, "# {summary}\n")?;
1731        };
1732
1733        for message in self.messages() {
1734            writeln!(
1735                markdown,
1736                "## {role}\n",
1737                role = match message.role {
1738                    Role::User => "User",
1739                    Role::Assistant => "Assistant",
1740                    Role::System => "System",
1741                }
1742            )?;
1743
1744            if !message.context.is_empty() {
1745                writeln!(markdown, "{}", message.context)?;
1746            }
1747
1748            for segment in &message.segments {
1749                match segment {
1750                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1751                    MessageSegment::Thinking(text) => {
1752                        writeln!(markdown, "<think>{}</think>\n", text)?
1753                    }
1754                }
1755            }
1756
1757            for tool_use in self.tool_uses_for_message(message.id, cx) {
1758                writeln!(
1759                    markdown,
1760                    "**Use Tool: {} ({})**",
1761                    tool_use.name, tool_use.id
1762                )?;
1763                writeln!(markdown, "```json")?;
1764                writeln!(
1765                    markdown,
1766                    "{}",
1767                    serde_json::to_string_pretty(&tool_use.input)?
1768                )?;
1769                writeln!(markdown, "```")?;
1770            }
1771
1772            for tool_result in self.tool_results_for_message(message.id) {
1773                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1774                if tool_result.is_error {
1775                    write!(markdown, " (Error)")?;
1776                }
1777
1778                writeln!(markdown, "**\n")?;
1779                writeln!(markdown, "{}", tool_result.content)?;
1780            }
1781        }
1782
1783        Ok(String::from_utf8_lossy(&markdown).to_string())
1784    }
1785
1786    pub fn keep_edits_in_range(
1787        &mut self,
1788        buffer: Entity<language::Buffer>,
1789        buffer_range: Range<language::Anchor>,
1790        cx: &mut Context<Self>,
1791    ) {
1792        self.action_log.update(cx, |action_log, cx| {
1793            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1794        });
1795    }
1796
1797    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1798        self.action_log
1799            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1800    }
1801
1802    pub fn reject_edits_in_range(
1803        &mut self,
1804        buffer: Entity<language::Buffer>,
1805        buffer_range: Range<language::Anchor>,
1806        cx: &mut Context<Self>,
1807    ) -> Task<Result<()>> {
1808        self.action_log.update(cx, |action_log, cx| {
1809            action_log.reject_edits_in_range(buffer, buffer_range, cx)
1810        })
1811    }
1812
    /// Returns the [`ActionLog`] entity held by this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1816
    /// Returns the [`Project`] entity held by this thread.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1820
1821    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1822        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1823            return;
1824        }
1825
1826        let now = Instant::now();
1827        if let Some(last) = self.last_auto_capture_at {
1828            if now.duration_since(last).as_secs() < 10 {
1829                return;
1830            }
1831        }
1832
1833        self.last_auto_capture_at = Some(now);
1834
1835        let thread_id = self.id().clone();
1836        let github_login = self
1837            .project
1838            .read(cx)
1839            .user_store()
1840            .read(cx)
1841            .current_user()
1842            .map(|user| user.github_login.clone());
1843        let client = self.project.read(cx).client().clone();
1844        let serialize_task = self.serialize(cx);
1845
1846        cx.background_executor()
1847            .spawn(async move {
1848                if let Ok(serialized_thread) = serialize_task.await {
1849                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1850                        telemetry::event!(
1851                            "Agent Thread Auto-Captured",
1852                            thread_id = thread_id.to_string(),
1853                            thread_data = thread_data,
1854                            auto_capture_reason = "tracked_user",
1855                            github_login = github_login
1856                        );
1857
1858                        client.telemetry().flush_events();
1859                    }
1860                }
1861            })
1862            .detach();
1863    }
1864
    /// Returns (by value) the token usage accumulated in this thread's
    /// `cumulative_token_usage` field.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
1868
1869    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1870        let model_registry = LanguageModelRegistry::read_global(cx);
1871        let Some(model) = model_registry.default_model() else {
1872            return TotalTokenUsage::default();
1873        };
1874
1875        let max = model.model.max_token_count();
1876
1877        if let Some(exceeded_error) = &self.exceeded_window_error {
1878            if model.model.id() == exceeded_error.model_id {
1879                return TotalTokenUsage {
1880                    total: exceeded_error.token_count,
1881                    max,
1882                    ratio: TokenUsageRatio::Exceeded,
1883                };
1884            }
1885        }
1886
1887        #[cfg(debug_assertions)]
1888        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1889            .unwrap_or("0.8".to_string())
1890            .parse()
1891            .unwrap();
1892        #[cfg(not(debug_assertions))]
1893        let warning_threshold: f32 = 0.8;
1894
1895        let total = self.cumulative_token_usage.total_tokens() as usize;
1896
1897        let ratio = if total >= max {
1898            TokenUsageRatio::Exceeded
1899        } else if total as f32 / max as f32 >= warning_threshold {
1900            TokenUsageRatio::Warning
1901        } else {
1902            TokenUsageRatio::Normal
1903        };
1904
1905        TotalTokenUsage { total, max, ratio }
1906    }
1907
1908    pub fn deny_tool_use(
1909        &mut self,
1910        tool_use_id: LanguageModelToolUseId,
1911        tool_name: Arc<str>,
1912        cx: &mut Context<Self>,
1913    ) {
1914        let err = Err(anyhow::anyhow!(
1915            "Permission to run tool action denied by user"
1916        ));
1917
1918        self.tool_use
1919            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1920        self.tool_finished(tool_use_id.clone(), None, true, cx);
1921    }
1922}
1923
/// Errors carried by [`ThreadEvent::ShowError`].
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Displayed as "Payment required".
    #[error("Payment required")]
    PaymentRequired,
    /// Displayed as "Max monthly spend reached".
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// A general error with a `header` and a `message`, displayed as
    /// "Message {header}: {message}".
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1936
/// Events emitted by a [`Thread`]; `Thread` implements
/// `EventEmitter<ThreadEvent>` directly below.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a [`ThreadError`] to be shown.
    ShowError(ThreadError),
    StreamedCompletion,
    StreamedAssistantText(MessageId, String),
    StreamedAssistantThinking(MessageId, String),
    /// Either the model's stop reason or the error that ended the completion.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    ToolFinished {
        // NOTE(review): `allow(unused)` suggests no subscriber reads this field;
        // confirm whether it is still needed before relying on it.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}

impl EventEmitter<ThreadEvent> for Thread {}
1963
/// A completion request that has been started and not yet cleared.
struct PendingCompletion {
    // Numeric id of this completion — presumably used to tell completions
    // apart; verify against the code that creates/removes them.
    id: usize,
    // Held but never read (underscore prefix); presumably kept so the task
    // stays alive for as long as this value exists — confirm.
    _task: Task<()>,
}
1968
// Tests for attaching context to user messages and for the completion
// requests built from a thread.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // A user message with one attached file stores the formatted file in
    // `message.context`, and the request places that context before the text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Index 0 is a leading message this test does not inspect (presumably
        // the system prompt); index 1 is the user message.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // Context already attached to an earlier message is not repeated in later
    // messages, either in the stored message or in the completion request.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Add each file to the context store individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // One leading message (presumably the system prompt) plus the 3 user messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    // Messages inserted with no context have an empty `context` and appear in
    // the request as their plain text.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Editing a buffer after it was attached as context makes the next request
    // end with a user message listing the changed file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at byte offset 1 (inside `fn`). The exact position does
            // not matter: the test only needs the buffer to change after it was
            // attached as context.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Installs a test `SettingsStore` and registers every settings type the
    /// thread, project, workspace and editor code under test read.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files
    /// Creates a project over a fake filesystem with `files` placed under `/test`.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Opens a test workspace for `project`, loads a `ThreadStore`, creates one
    /// thread in it, and creates an empty `ContextStore` for the project.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    Arc::default(),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await;

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    /// Opens the buffer at `path` (a project-relative path such as
    /// `"test/code.rs"`) and adds it to `context_store`, returning the buffer.
    ///
    /// Panics if the path cannot be resolved or opened; returns an error only if
    /// adding the buffer to the context store fails.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}