thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  19    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  20    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  21    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
  22    Role, StopReason, TokenUsage,
  23};
  24use project::Project;
  25use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  26use prompt_store::PromptBuilder;
  27use schemars::JsonSchema;
  28use serde::{Deserialize, Serialize};
  29use settings::Settings;
  30use thiserror::Error;
  31use util::{ResultExt as _, TryFutureExt as _, post_inc};
  32use uuid::Uuid;
  33
  34use crate::context::{AssistantContext, ContextId, format_context_as_string};
  35use crate::thread_store::{
  36    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  37    SerializedToolUse, SharedProjectContext,
  38};
  39use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  40
  41#[derive(Debug, Clone, Copy)]
  42pub enum RequestKind {
  43    Chat,
  44    /// Used when summarizing a thread.
  45    Summarize,
  46}
  47
  48#[derive(
  49    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  50)]
  51pub struct ThreadId(Arc<str>);
  52
  53impl ThreadId {
  54    pub fn new() -> Self {
  55        Self(Uuid::new_v4().to_string().into())
  56    }
  57}
  58
  59impl std::fmt::Display for ThreadId {
  60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  61        write!(f, "{}", self.0)
  62    }
  63}
  64
  65impl From<&str> for ThreadId {
  66    fn from(value: &str) -> Self {
  67        Self(value.into())
  68    }
  69}
  70
  71#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  72pub struct MessageId(pub(crate) usize);
  73
  74impl MessageId {
  75    fn post_inc(&mut self) -> Self {
  76        Self(post_inc(&mut self.0))
  77    }
  78}
  79
  80/// A message in a [`Thread`].
  81#[derive(Debug, Clone)]
  82pub struct Message {
  83    pub id: MessageId,
  84    pub role: Role,
  85    pub segments: Vec<MessageSegment>,
  86    pub context: String,
  87}
  88
  89impl Message {
  90    /// Returns whether the message contains any meaningful text that should be displayed
  91    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  92    pub fn should_display_content(&self) -> bool {
  93        self.segments.iter().all(|segment| segment.should_display())
  94    }
  95
  96    pub fn push_thinking(&mut self, text: &str) {
  97        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  98            segment.push_str(text);
  99        } else {
 100            self.segments
 101                .push(MessageSegment::Thinking(text.to_string()));
 102        }
 103    }
 104
 105    pub fn push_text(&mut self, text: &str) {
 106        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 107            segment.push_str(text);
 108        } else {
 109            self.segments.push(MessageSegment::Text(text.to_string()));
 110        }
 111    }
 112
 113    pub fn to_string(&self) -> String {
 114        let mut result = String::new();
 115
 116        if !self.context.is_empty() {
 117            result.push_str(&self.context);
 118        }
 119
 120        for segment in &self.segments {
 121            match segment {
 122                MessageSegment::Text(text) => result.push_str(text),
 123                MessageSegment::Thinking(text) => {
 124                    result.push_str("<think>");
 125                    result.push_str(text);
 126                    result.push_str("</think>");
 127                }
 128            }
 129        }
 130
 131        result
 132    }
 133}
 134
 135#[derive(Debug, Clone, PartialEq, Eq)]
 136pub enum MessageSegment {
 137    Text(String),
 138    Thinking(String),
 139}
 140
 141impl MessageSegment {
 142    pub fn text_mut(&mut self) -> &mut String {
 143        match self {
 144            Self::Text(text) => text,
 145            Self::Thinking(text) => text,
 146        }
 147    }
 148
 149    pub fn should_display(&self) -> bool {
 150        // We add USING_TOOL_MARKER when making a request that includes tool uses
 151        // without non-whitespace text around them, and this can cause the model
 152        // to mimic the pattern, so we consider those segments not displayable.
 153        match self {
 154            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 155            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 156        }
 157    }
 158}
 159
 160#[derive(Debug, Clone, Serialize, Deserialize)]
 161pub struct ProjectSnapshot {
 162    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 163    pub unsaved_buffer_paths: Vec<String>,
 164    pub timestamp: DateTime<Utc>,
 165}
 166
 167#[derive(Debug, Clone, Serialize, Deserialize)]
 168pub struct WorktreeSnapshot {
 169    pub worktree_path: String,
 170    pub git_state: Option<GitState>,
 171}
 172
 173#[derive(Debug, Clone, Serialize, Deserialize)]
 174pub struct GitState {
 175    pub remote_url: Option<String>,
 176    pub head_sha: Option<String>,
 177    pub current_branch: Option<String>,
 178    pub diff: Option<String>,
 179}
 180
 181#[derive(Clone)]
 182pub struct ThreadCheckpoint {
 183    message_id: MessageId,
 184    git_checkpoint: GitStoreCheckpoint,
 185}
 186
 187#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 188pub enum ThreadFeedback {
 189    Positive,
 190    Negative,
 191}
 192
 193pub enum LastRestoreCheckpoint {
 194    Pending {
 195        message_id: MessageId,
 196    },
 197    Error {
 198        message_id: MessageId,
 199        error: String,
 200    },
 201}
 202
 203impl LastRestoreCheckpoint {
 204    pub fn message_id(&self) -> MessageId {
 205        match self {
 206            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 207            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 208        }
 209    }
 210}
 211
 212#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 213pub enum DetailedSummaryState {
 214    #[default]
 215    NotGenerated,
 216    Generating {
 217        message_id: MessageId,
 218    },
 219    Generated {
 220        text: SharedString,
 221        message_id: MessageId,
 222    },
 223}
 224
 225#[derive(Default)]
 226pub struct TotalTokenUsage {
 227    pub total: usize,
 228    pub max: usize,
 229    pub ratio: TokenUsageRatio,
 230}
 231
 232#[derive(Debug, Default, PartialEq, Eq)]
 233pub enum TokenUsageRatio {
 234    #[default]
 235    Normal,
 236    Warning,
 237    Exceeded,
 238}
 239
 240/// A thread of conversation with the LLM.
 241pub struct Thread {
 242    id: ThreadId,
 243    updated_at: DateTime<Utc>,
 244    summary: Option<SharedString>,
 245    pending_summary: Task<Option<()>>,
 246    detailed_summary_state: DetailedSummaryState,
 247    messages: Vec<Message>,
 248    next_message_id: MessageId,
 249    context: BTreeMap<ContextId, AssistantContext>,
 250    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 251    project_context: SharedProjectContext,
 252    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 253    completion_count: usize,
 254    pending_completions: Vec<PendingCompletion>,
 255    project: Entity<Project>,
 256    prompt_builder: Arc<PromptBuilder>,
 257    tools: Arc<ToolWorkingSet>,
 258    tool_use: ToolUseState,
 259    action_log: Entity<ActionLog>,
 260    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 261    pending_checkpoint: Option<ThreadCheckpoint>,
 262    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 263    cumulative_token_usage: TokenUsage,
 264    exceeded_window_error: Option<ExceededWindowError>,
 265    feedback: Option<ThreadFeedback>,
 266    message_feedback: HashMap<MessageId, ThreadFeedback>,
 267    last_auto_capture_at: Option<Instant>,
 268}
 269
 270#[derive(Debug, Clone, Serialize, Deserialize)]
 271pub struct ExceededWindowError {
 272    /// Model used when last message exceeded context window
 273    model_id: LanguageModelId,
 274    /// Token count including last message
 275    token_count: usize,
 276}
 277
 278impl Thread {
 279    pub fn new(
 280        project: Entity<Project>,
 281        tools: Arc<ToolWorkingSet>,
 282        prompt_builder: Arc<PromptBuilder>,
 283        system_prompt: SharedProjectContext,
 284        cx: &mut Context<Self>,
 285    ) -> Self {
 286        Self {
 287            id: ThreadId::new(),
 288            updated_at: Utc::now(),
 289            summary: None,
 290            pending_summary: Task::ready(None),
 291            detailed_summary_state: DetailedSummaryState::NotGenerated,
 292            messages: Vec::new(),
 293            next_message_id: MessageId(0),
 294            context: BTreeMap::default(),
 295            context_by_message: HashMap::default(),
 296            project_context: system_prompt,
 297            checkpoints_by_message: HashMap::default(),
 298            completion_count: 0,
 299            pending_completions: Vec::new(),
 300            project: project.clone(),
 301            prompt_builder,
 302            tools: tools.clone(),
 303            last_restore_checkpoint: None,
 304            pending_checkpoint: None,
 305            tool_use: ToolUseState::new(tools.clone()),
 306            action_log: cx.new(|_| ActionLog::new(project.clone())),
 307            initial_project_snapshot: {
 308                let project_snapshot = Self::project_snapshot(project, cx);
 309                cx.foreground_executor()
 310                    .spawn(async move { Some(project_snapshot.await) })
 311                    .shared()
 312            },
 313            cumulative_token_usage: TokenUsage::default(),
 314            exceeded_window_error: None,
 315            feedback: None,
 316            message_feedback: HashMap::default(),
 317            last_auto_capture_at: None,
 318        }
 319    }
 320
 321    pub fn deserialize(
 322        id: ThreadId,
 323        serialized: SerializedThread,
 324        project: Entity<Project>,
 325        tools: Arc<ToolWorkingSet>,
 326        prompt_builder: Arc<PromptBuilder>,
 327        project_context: SharedProjectContext,
 328        cx: &mut Context<Self>,
 329    ) -> Self {
 330        let next_message_id = MessageId(
 331            serialized
 332                .messages
 333                .last()
 334                .map(|message| message.id.0 + 1)
 335                .unwrap_or(0),
 336        );
 337        let tool_use =
 338            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 339
 340        Self {
 341            id,
 342            updated_at: serialized.updated_at,
 343            summary: Some(serialized.summary),
 344            pending_summary: Task::ready(None),
 345            detailed_summary_state: serialized.detailed_summary_state,
 346            messages: serialized
 347                .messages
 348                .into_iter()
 349                .map(|message| Message {
 350                    id: message.id,
 351                    role: message.role,
 352                    segments: message
 353                        .segments
 354                        .into_iter()
 355                        .map(|segment| match segment {
 356                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 357                            SerializedMessageSegment::Thinking { text } => {
 358                                MessageSegment::Thinking(text)
 359                            }
 360                        })
 361                        .collect(),
 362                    context: message.context,
 363                })
 364                .collect(),
 365            next_message_id,
 366            context: BTreeMap::default(),
 367            context_by_message: HashMap::default(),
 368            project_context,
 369            checkpoints_by_message: HashMap::default(),
 370            completion_count: 0,
 371            pending_completions: Vec::new(),
 372            last_restore_checkpoint: None,
 373            pending_checkpoint: None,
 374            project: project.clone(),
 375            prompt_builder,
 376            tools,
 377            tool_use,
 378            action_log: cx.new(|_| ActionLog::new(project)),
 379            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 380            cumulative_token_usage: serialized.cumulative_token_usage,
 381            exceeded_window_error: None,
 382            feedback: None,
 383            message_feedback: HashMap::default(),
 384            last_auto_capture_at: None,
 385        }
 386    }
 387
 388    pub fn id(&self) -> &ThreadId {
 389        &self.id
 390    }
 391
 392    pub fn is_empty(&self) -> bool {
 393        self.messages.is_empty()
 394    }
 395
 396    pub fn updated_at(&self) -> DateTime<Utc> {
 397        self.updated_at
 398    }
 399
 400    pub fn touch_updated_at(&mut self) {
 401        self.updated_at = Utc::now();
 402    }
 403
 404    pub fn summary(&self) -> Option<SharedString> {
 405        self.summary.clone()
 406    }
 407
 408    pub fn project_context(&self) -> SharedProjectContext {
 409        self.project_context.clone()
 410    }
 411
 412    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 413
 414    pub fn summary_or_default(&self) -> SharedString {
 415        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 416    }
 417
 418    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 419        let Some(current_summary) = &self.summary else {
 420            // Don't allow setting summary until generated
 421            return;
 422        };
 423
 424        let mut new_summary = new_summary.into();
 425
 426        if new_summary.is_empty() {
 427            new_summary = Self::DEFAULT_SUMMARY;
 428        }
 429
 430        if current_summary != &new_summary {
 431            self.summary = Some(new_summary);
 432            cx.emit(ThreadEvent::SummaryChanged);
 433        }
 434    }
 435
 436    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 437        self.latest_detailed_summary()
 438            .unwrap_or_else(|| self.text().into())
 439    }
 440
 441    fn latest_detailed_summary(&self) -> Option<SharedString> {
 442        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 443            Some(text.clone())
 444        } else {
 445            None
 446        }
 447    }
 448
 449    pub fn message(&self, id: MessageId) -> Option<&Message> {
 450        self.messages.iter().find(|message| message.id == id)
 451    }
 452
 453    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 454        self.messages.iter()
 455    }
 456
 457    pub fn is_generating(&self) -> bool {
 458        !self.pending_completions.is_empty() || !self.all_tools_finished()
 459    }
 460
 461    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 462        &self.tools
 463    }
 464
 465    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 466        self.tool_use
 467            .pending_tool_uses()
 468            .into_iter()
 469            .find(|tool_use| &tool_use.id == id)
 470    }
 471
 472    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 473        self.tool_use
 474            .pending_tool_uses()
 475            .into_iter()
 476            .filter(|tool_use| tool_use.status.needs_confirmation())
 477    }
 478
 479    pub fn has_pending_tool_uses(&self) -> bool {
 480        !self.tool_use.pending_tool_uses().is_empty()
 481    }
 482
 483    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 484        self.checkpoints_by_message.get(&id).cloned()
 485    }
 486
 487    pub fn restore_checkpoint(
 488        &mut self,
 489        checkpoint: ThreadCheckpoint,
 490        cx: &mut Context<Self>,
 491    ) -> Task<Result<()>> {
 492        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 493            message_id: checkpoint.message_id,
 494        });
 495        cx.emit(ThreadEvent::CheckpointChanged);
 496        cx.notify();
 497
 498        let git_store = self.project().read(cx).git_store().clone();
 499        let restore = git_store.update(cx, |git_store, cx| {
 500            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 501        });
 502
 503        cx.spawn(async move |this, cx| {
 504            let result = restore.await;
 505            this.update(cx, |this, cx| {
 506                if let Err(err) = result.as_ref() {
 507                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 508                        message_id: checkpoint.message_id,
 509                        error: err.to_string(),
 510                    });
 511                } else {
 512                    this.truncate(checkpoint.message_id, cx);
 513                    this.last_restore_checkpoint = None;
 514                }
 515                this.pending_checkpoint = None;
 516                cx.emit(ThreadEvent::CheckpointChanged);
 517                cx.notify();
 518            })?;
 519            result
 520        })
 521    }
 522
 523    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 524        let pending_checkpoint = if self.is_generating() {
 525            return;
 526        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 527            checkpoint
 528        } else {
 529            return;
 530        };
 531
 532        let git_store = self.project.read(cx).git_store().clone();
 533        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 534        cx.spawn(async move |this, cx| match final_checkpoint.await {
 535            Ok(final_checkpoint) => {
 536                let equal = git_store
 537                    .update(cx, |store, cx| {
 538                        store.compare_checkpoints(
 539                            pending_checkpoint.git_checkpoint.clone(),
 540                            final_checkpoint.clone(),
 541                            cx,
 542                        )
 543                    })?
 544                    .await
 545                    .unwrap_or(false);
 546
 547                if equal {
 548                    git_store
 549                        .update(cx, |store, cx| {
 550                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 551                        })?
 552                        .detach();
 553                } else {
 554                    this.update(cx, |this, cx| {
 555                        this.insert_checkpoint(pending_checkpoint, cx)
 556                    })?;
 557                }
 558
 559                git_store
 560                    .update(cx, |store, cx| {
 561                        store.delete_checkpoint(final_checkpoint, cx)
 562                    })?
 563                    .detach();
 564
 565                Ok(())
 566            }
 567            Err(_) => this.update(cx, |this, cx| {
 568                this.insert_checkpoint(pending_checkpoint, cx)
 569            }),
 570        })
 571        .detach();
 572    }
 573
 574    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 575        self.checkpoints_by_message
 576            .insert(checkpoint.message_id, checkpoint);
 577        cx.emit(ThreadEvent::CheckpointChanged);
 578        cx.notify();
 579    }
 580
 581    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 582        self.last_restore_checkpoint.as_ref()
 583    }
 584
 585    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 586        let Some(message_ix) = self
 587            .messages
 588            .iter()
 589            .rposition(|message| message.id == message_id)
 590        else {
 591            return;
 592        };
 593        for deleted_message in self.messages.drain(message_ix..) {
 594            self.context_by_message.remove(&deleted_message.id);
 595            self.checkpoints_by_message.remove(&deleted_message.id);
 596        }
 597        cx.notify();
 598    }
 599
 600    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 601        self.context_by_message
 602            .get(&id)
 603            .into_iter()
 604            .flat_map(|context| {
 605                context
 606                    .iter()
 607                    .filter_map(|context_id| self.context.get(&context_id))
 608            })
 609    }
 610
 611    /// Returns whether all of the tool uses have finished running.
 612    pub fn all_tools_finished(&self) -> bool {
 613        // If the only pending tool uses left are the ones with errors, then
 614        // that means that we've finished running all of the pending tools.
 615        self.tool_use
 616            .pending_tool_uses()
 617            .iter()
 618            .all(|tool_use| tool_use.status.is_error())
 619    }
 620
 621    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 622        self.tool_use.tool_uses_for_message(id, cx)
 623    }
 624
 625    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 626        self.tool_use.tool_results_for_message(id)
 627    }
 628
 629    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 630        self.tool_use.tool_result(id)
 631    }
 632
 633    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 634        self.tool_use.message_has_tool_results(message_id)
 635    }
 636
 637    pub fn insert_user_message(
 638        &mut self,
 639        text: impl Into<String>,
 640        context: Vec<AssistantContext>,
 641        git_checkpoint: Option<GitStoreCheckpoint>,
 642        cx: &mut Context<Self>,
 643    ) -> MessageId {
 644        let text = text.into();
 645
 646        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 647
 648        // Filter out contexts that have already been included in previous messages
 649        let new_context: Vec<_> = context
 650            .into_iter()
 651            .filter(|ctx| !self.context.contains_key(&ctx.id()))
 652            .collect();
 653
 654        if !new_context.is_empty() {
 655            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 656                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 657                    message.context = context_string;
 658                }
 659            }
 660
 661            self.action_log.update(cx, |log, cx| {
 662                // Track all buffers added as context
 663                for ctx in &new_context {
 664                    match ctx {
 665                        AssistantContext::File(file_ctx) => {
 666                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 667                        }
 668                        AssistantContext::Directory(dir_ctx) => {
 669                            for context_buffer in &dir_ctx.context_buffers {
 670                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 671                            }
 672                        }
 673                        AssistantContext::Symbol(symbol_ctx) => {
 674                            log.buffer_added_as_context(
 675                                symbol_ctx.context_symbol.buffer.clone(),
 676                                cx,
 677                            );
 678                        }
 679                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 680                    }
 681                }
 682            });
 683        }
 684
 685        let context_ids = new_context
 686            .iter()
 687            .map(|context| context.id())
 688            .collect::<Vec<_>>();
 689        self.context.extend(
 690            new_context
 691                .into_iter()
 692                .map(|context| (context.id(), context)),
 693        );
 694        self.context_by_message.insert(message_id, context_ids);
 695
 696        if let Some(git_checkpoint) = git_checkpoint {
 697            self.pending_checkpoint = Some(ThreadCheckpoint {
 698                message_id,
 699                git_checkpoint,
 700            });
 701        }
 702
 703        self.auto_capture_telemetry(cx);
 704
 705        message_id
 706    }
 707
 708    pub fn insert_message(
 709        &mut self,
 710        role: Role,
 711        segments: Vec<MessageSegment>,
 712        cx: &mut Context<Self>,
 713    ) -> MessageId {
 714        let id = self.next_message_id.post_inc();
 715        self.messages.push(Message {
 716            id,
 717            role,
 718            segments,
 719            context: String::new(),
 720        });
 721        self.touch_updated_at();
 722        cx.emit(ThreadEvent::MessageAdded(id));
 723        id
 724    }
 725
 726    pub fn edit_message(
 727        &mut self,
 728        id: MessageId,
 729        new_role: Role,
 730        new_segments: Vec<MessageSegment>,
 731        cx: &mut Context<Self>,
 732    ) -> bool {
 733        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 734            return false;
 735        };
 736        message.role = new_role;
 737        message.segments = new_segments;
 738        self.touch_updated_at();
 739        cx.emit(ThreadEvent::MessageEdited(id));
 740        true
 741    }
 742
 743    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 744        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 745            return false;
 746        };
 747        self.messages.remove(index);
 748        self.context_by_message.remove(&id);
 749        self.touch_updated_at();
 750        cx.emit(ThreadEvent::MessageDeleted(id));
 751        true
 752    }
 753
 754    /// Returns the representation of this [`Thread`] in a textual form.
 755    ///
 756    /// This is the representation we use when attaching a thread as context to another thread.
 757    pub fn text(&self) -> String {
 758        let mut text = String::new();
 759
 760        for message in &self.messages {
 761            text.push_str(match message.role {
 762                language_model::Role::User => "User:",
 763                language_model::Role::Assistant => "Assistant:",
 764                language_model::Role::System => "System:",
 765            });
 766            text.push('\n');
 767
 768            for segment in &message.segments {
 769                match segment {
 770                    MessageSegment::Text(content) => text.push_str(content),
 771                    MessageSegment::Thinking(content) => {
 772                        text.push_str(&format!("<think>{}</think>", content))
 773                    }
 774                }
 775            }
 776            text.push('\n');
 777        }
 778
 779        text
 780    }
 781
 782    /// Serializes this thread into a format for storage or telemetry.
 783    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 784        let initial_project_snapshot = self.initial_project_snapshot.clone();
 785        cx.spawn(async move |this, cx| {
 786            let initial_project_snapshot = initial_project_snapshot.await;
 787            this.read_with(cx, |this, cx| SerializedThread {
 788                version: SerializedThread::VERSION.to_string(),
 789                summary: this.summary_or_default(),
 790                updated_at: this.updated_at(),
 791                messages: this
 792                    .messages()
 793                    .map(|message| SerializedMessage {
 794                        id: message.id,
 795                        role: message.role,
 796                        segments: message
 797                            .segments
 798                            .iter()
 799                            .map(|segment| match segment {
 800                                MessageSegment::Text(text) => {
 801                                    SerializedMessageSegment::Text { text: text.clone() }
 802                                }
 803                                MessageSegment::Thinking(text) => {
 804                                    SerializedMessageSegment::Thinking { text: text.clone() }
 805                                }
 806                            })
 807                            .collect(),
 808                        tool_uses: this
 809                            .tool_uses_for_message(message.id, cx)
 810                            .into_iter()
 811                            .map(|tool_use| SerializedToolUse {
 812                                id: tool_use.id,
 813                                name: tool_use.name,
 814                                input: tool_use.input,
 815                            })
 816                            .collect(),
 817                        tool_results: this
 818                            .tool_results_for_message(message.id)
 819                            .into_iter()
 820                            .map(|tool_result| SerializedToolResult {
 821                                tool_use_id: tool_result.tool_use_id.clone(),
 822                                is_error: tool_result.is_error,
 823                                content: tool_result.content.clone(),
 824                            })
 825                            .collect(),
 826                        context: message.context.clone(),
 827                    })
 828                    .collect(),
 829                initial_project_snapshot,
 830                cumulative_token_usage: this.cumulative_token_usage.clone(),
 831                detailed_summary_state: this.detailed_summary_state.clone(),
 832                exceeded_window_error: this.exceeded_window_error.clone(),
 833            })
 834        })
 835    }
 836
 837    pub fn send_to_model(
 838        &mut self,
 839        model: Arc<dyn LanguageModel>,
 840        request_kind: RequestKind,
 841        cx: &mut Context<Self>,
 842    ) {
 843        let mut request = self.to_completion_request(request_kind, cx);
 844        if model.supports_tools() {
 845            request.tools = {
 846                let mut tools = Vec::new();
 847                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 848                    LanguageModelRequestTool {
 849                        name: tool.name(),
 850                        description: tool.description(),
 851                        input_schema: tool.input_schema(model.tool_input_format()),
 852                    }
 853                }));
 854
 855                tools
 856            };
 857        }
 858
 859        self.stream_completion(request, model, cx);
 860    }
 861
 862    pub fn used_tools_since_last_user_message(&self) -> bool {
 863        for message in self.messages.iter().rev() {
 864            if self.tool_use.message_has_tool_results(message.id) {
 865                return true;
 866            } else if message.role == Role::User {
 867                return false;
 868            }
 869        }
 870
 871        false
 872    }
 873
 874    pub fn to_completion_request(
 875        &self,
 876        request_kind: RequestKind,
 877        cx: &App,
 878    ) -> LanguageModelRequest {
 879        let mut request = LanguageModelRequest {
 880            messages: vec![],
 881            tools: Vec::new(),
 882            stop: Vec::new(),
 883            temperature: None,
 884        };
 885
 886        if let Some(project_context) = self.project_context.borrow().as_ref() {
 887            if let Some(system_prompt) = self
 888                .prompt_builder
 889                .generate_assistant_system_prompt(project_context)
 890                .context("failed to generate assistant system prompt")
 891                .log_err()
 892            {
 893                request.messages.push(LanguageModelRequestMessage {
 894                    role: Role::System,
 895                    content: vec![MessageContent::Text(system_prompt)],
 896                    cache: true,
 897                });
 898            }
 899        } else {
 900            log::error!("project_context not set.")
 901        }
 902
 903        for message in &self.messages {
 904            let mut request_message = LanguageModelRequestMessage {
 905                role: message.role,
 906                content: Vec::new(),
 907                cache: false,
 908            };
 909
 910            match request_kind {
 911                RequestKind::Chat => {
 912                    self.tool_use
 913                        .attach_tool_results(message.id, &mut request_message);
 914                }
 915                RequestKind::Summarize => {
 916                    // We don't care about tool use during summarization.
 917                    if self.tool_use.message_has_tool_results(message.id) {
 918                        continue;
 919                    }
 920                }
 921            }
 922
 923            if !message.segments.is_empty() {
 924                request_message
 925                    .content
 926                    .push(MessageContent::Text(message.to_string()));
 927            }
 928
 929            match request_kind {
 930                RequestKind::Chat => {
 931                    self.tool_use
 932                        .attach_tool_uses(message.id, &mut request_message);
 933                }
 934                RequestKind::Summarize => {
 935                    // We don't care about tool use during summarization.
 936                }
 937            };
 938
 939            request.messages.push(request_message);
 940        }
 941
 942        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 943        if let Some(last) = request.messages.last_mut() {
 944            last.cache = true;
 945        }
 946
 947        self.attached_tracked_files_state(&mut request.messages, cx);
 948
 949        request
 950    }
 951
 952    fn attached_tracked_files_state(
 953        &self,
 954        messages: &mut Vec<LanguageModelRequestMessage>,
 955        cx: &App,
 956    ) {
 957        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 958
 959        let mut stale_message = String::new();
 960
 961        let action_log = self.action_log.read(cx);
 962
 963        for stale_file in action_log.stale_buffers(cx) {
 964            let Some(file) = stale_file.read(cx).file() else {
 965                continue;
 966            };
 967
 968            if stale_message.is_empty() {
 969                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
 970            }
 971
 972            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 973        }
 974
 975        let mut content = Vec::with_capacity(2);
 976
 977        if !stale_message.is_empty() {
 978            content.push(stale_message.into());
 979        }
 980
 981        if action_log.has_edited_files_since_project_diagnostics_check() {
 982            content.push(
 983                "\n\nWhen you're done making changes, make sure to check project diagnostics \
 984                and fix all errors AND warnings you introduced! \
 985                DO NOT mention you're going to do this until you're done."
 986                    .into(),
 987            );
 988        }
 989
 990        if !content.is_empty() {
 991            let context_message = LanguageModelRequestMessage {
 992                role: Role::User,
 993                content,
 994                cache: false,
 995            };
 996
 997            messages.push(context_message);
 998        }
 999    }
1000
    /// Streams `request` to `model` and applies the resulting events to this thread.
    ///
    /// The spawned task is registered in `pending_completions` under an id taken
    /// (post-incremented) from `completion_count`. While streaming, each event
    /// updates the thread and then emits [`ThreadEvent::StreamedCompletion`]:
    /// - `StartMessage` inserts an empty assistant message;
    /// - `Text` / `Thinking` chunks are appended to the last message when it is an
    ///   assistant message, otherwise they start a new assistant message;
    /// - `ToolUse` requests a tool use on the most recent assistant message,
    ///   inserting one if none exists;
    /// - `UsageUpdate` adjusts `cumulative_token_usage`;
    /// - `Stop` records the stop reason.
    ///
    /// When the stream ends without error, this completion is removed from
    /// `pending_completions`. A title summary is also requested if the thread has
    /// none and at least two messages.
    ///
    /// Afterwards, if the thread still exists:
    /// - the pending checkpoint is finalized;
    /// - `StopReason::ToolUse` triggers [`Self::use_pending_tools`];
    /// - errors are surfaced as [`ThreadEvent::ShowError`] (or recorded as an
    ///   [`ExceededWindowError`]) and the last completion is canceled;
    /// - [`ThreadEvent::Stopped`] is emitted;
    /// - a telemetry event with the token usage accumulated during this call is
    ///   sent.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            // Snapshot usage up front so the telemetry at the end can report only the
            // tokens consumed by this completion.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Assumes each update carries this completion's running
                                // total (TODO confirm): the previously counted usage is
                                // subtracted so it is not added twice.
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If the last message is not an Assistant message, assume this chunk
                                        // marks the beginning of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If the last message is not an Assistant message, assume this chunk
                                        // marks the beginning of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the most recent assistant message,
                                // creating an empty one if the thread has none.
                                let last_assistant_message_id = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant)
                                    .map(|message| message.id)
                                    .unwrap_or_else(|| {
                                        thread.insert_message(Role::Assistant, vec![], cx)
                                    });

                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();

                        thread.auto_capture_telemetry(cx);
                    })?;

                    // Give other tasks a chance to run between events.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                let tool_uses = thread.use_pending_tools(cx);
                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else if let Some(known_error) =
                                error.downcast_ref::<LanguageModelKnownError>()
                            {
                                match known_error {
                                    LanguageModelKnownError::ContextWindowLimitExceeded {
                                        tokens,
                                    } => {
                                        thread.exceeded_window_error = Some(ExceededWindowError {
                                            model_id: model.id(),
                                            token_count: *tokens,
                                        });
                                        cx.notify();
                                    }
                                }
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            // The `retain` in the success path is skipped on error, so
                            // drop the most recently registered pending completion here
                            // (presumably this one; confirm no newer completion can exist).
                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));

                    thread.auto_capture_telemetry(cx);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1201
1202    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1203        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1204            return;
1205        };
1206
1207        if !model.provider.is_authenticated(cx) {
1208            return;
1209        }
1210
1211        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1212        request.messages.push(LanguageModelRequestMessage {
1213            role: Role::User,
1214            content: vec![
1215                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1216                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1217                 If the conversation is about a specific subject, include it in the title. \
1218                 Be descriptive. DO NOT speak in the first person."
1219                    .into(),
1220            ],
1221            cache: false,
1222        });
1223
1224        self.pending_summary = cx.spawn(async move |this, cx| {
1225            async move {
1226                let stream = model.model.stream_completion_text(request, &cx);
1227                let mut messages = stream.await?;
1228
1229                let mut new_summary = String::new();
1230                while let Some(message) = messages.stream.next().await {
1231                    let text = message?;
1232                    let mut lines = text.lines();
1233                    new_summary.extend(lines.next());
1234
1235                    // Stop if the LLM generated multiple lines.
1236                    if lines.next().is_some() {
1237                        break;
1238                    }
1239                }
1240
1241                this.update(cx, |this, cx| {
1242                    if !new_summary.is_empty() {
1243                        this.summary = Some(new_summary.into());
1244                    }
1245
1246                    cx.emit(ThreadEvent::SummaryGenerated);
1247                })?;
1248
1249                anyhow::Ok(())
1250            }
1251            .log_err()
1252            .await
1253        });
1254    }
1255
1256    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1257        let last_message_id = self.messages.last().map(|message| message.id)?;
1258
1259        match &self.detailed_summary_state {
1260            DetailedSummaryState::Generating { message_id, .. }
1261            | DetailedSummaryState::Generated { message_id, .. }
1262                if *message_id == last_message_id =>
1263            {
1264                // Already up-to-date
1265                return None;
1266            }
1267            _ => {}
1268        }
1269
1270        let ConfiguredModel { model, provider } =
1271            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1272
1273        if !provider.is_authenticated(cx) {
1274            return None;
1275        }
1276
1277        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1278
1279        request.messages.push(LanguageModelRequestMessage {
1280            role: Role::User,
1281            content: vec![
1282                "Generate a detailed summary of this conversation. Include:\n\
1283                1. A brief overview of what was discussed\n\
1284                2. Key facts or information discovered\n\
1285                3. Outcomes or conclusions reached\n\
1286                4. Any action items or next steps if any\n\
1287                Format it in Markdown with headings and bullet points."
1288                    .into(),
1289            ],
1290            cache: false,
1291        });
1292
1293        let task = cx.spawn(async move |thread, cx| {
1294            let stream = model.stream_completion_text(request, &cx);
1295            let Some(mut messages) = stream.await.log_err() else {
1296                thread
1297                    .update(cx, |this, _cx| {
1298                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1299                    })
1300                    .log_err();
1301
1302                return;
1303            };
1304
1305            let mut new_detailed_summary = String::new();
1306
1307            while let Some(chunk) = messages.stream.next().await {
1308                if let Some(chunk) = chunk.log_err() {
1309                    new_detailed_summary.push_str(&chunk);
1310                }
1311            }
1312
1313            thread
1314                .update(cx, |this, _cx| {
1315                    this.detailed_summary_state = DetailedSummaryState::Generated {
1316                        text: new_detailed_summary.into(),
1317                        message_id: last_message_id,
1318                    };
1319                })
1320                .log_err();
1321        });
1322
1323        self.detailed_summary_state = DetailedSummaryState::Generating {
1324            message_id: last_message_id,
1325        };
1326
1327        Some(task)
1328    }
1329
1330    pub fn is_generating_detailed_summary(&self) -> bool {
1331        matches!(
1332            self.detailed_summary_state,
1333            DetailedSummaryState::Generating { .. }
1334        )
1335    }
1336
1337    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1338        self.auto_capture_telemetry(cx);
1339        let request = self.to_completion_request(RequestKind::Chat, cx);
1340        let messages = Arc::new(request.messages);
1341        let pending_tool_uses = self
1342            .tool_use
1343            .pending_tool_uses()
1344            .into_iter()
1345            .filter(|tool_use| tool_use.status.is_idle())
1346            .cloned()
1347            .collect::<Vec<_>>();
1348
1349        for tool_use in pending_tool_uses.iter() {
1350            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1351                if tool.needs_confirmation(&tool_use.input, cx)
1352                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1353                {
1354                    self.tool_use.confirm_tool_use(
1355                        tool_use.id.clone(),
1356                        tool_use.ui_text.clone(),
1357                        tool_use.input.clone(),
1358                        messages.clone(),
1359                        tool,
1360                    );
1361                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1362                } else {
1363                    self.run_tool(
1364                        tool_use.id.clone(),
1365                        tool_use.ui_text.clone(),
1366                        tool_use.input.clone(),
1367                        &messages,
1368                        tool,
1369                        cx,
1370                    );
1371                }
1372            }
1373        }
1374
1375        pending_tool_uses
1376    }
1377
1378    pub fn run_tool(
1379        &mut self,
1380        tool_use_id: LanguageModelToolUseId,
1381        ui_text: impl Into<SharedString>,
1382        input: serde_json::Value,
1383        messages: &[LanguageModelRequestMessage],
1384        tool: Arc<dyn Tool>,
1385        cx: &mut Context<Thread>,
1386    ) {
1387        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1388        self.tool_use
1389            .run_pending_tool(tool_use_id, ui_text.into(), task);
1390    }
1391
1392    fn spawn_tool_use(
1393        &mut self,
1394        tool_use_id: LanguageModelToolUseId,
1395        messages: &[LanguageModelRequestMessage],
1396        input: serde_json::Value,
1397        tool: Arc<dyn Tool>,
1398        cx: &mut Context<Thread>,
1399    ) -> Task<()> {
1400        let tool_name: Arc<str> = tool.name().into();
1401
1402        let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1403            Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1404        } else {
1405            tool.run(
1406                input,
1407                messages,
1408                self.project.clone(),
1409                self.action_log.clone(),
1410                cx,
1411            )
1412        };
1413
1414        cx.spawn({
1415            async move |thread: WeakEntity<Thread>, cx| {
1416                let output = run_tool.await;
1417
1418                thread
1419                    .update(cx, |thread, cx| {
1420                        let pending_tool_use = thread.tool_use.insert_tool_output(
1421                            tool_use_id.clone(),
1422                            tool_name,
1423                            output,
1424                            cx,
1425                        );
1426                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1427                    })
1428                    .ok();
1429            }
1430        })
1431    }
1432
1433    fn tool_finished(
1434        &mut self,
1435        tool_use_id: LanguageModelToolUseId,
1436        pending_tool_use: Option<PendingToolUse>,
1437        canceled: bool,
1438        cx: &mut Context<Self>,
1439    ) {
1440        if self.all_tools_finished() {
1441            let model_registry = LanguageModelRegistry::read_global(cx);
1442            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1443                self.attach_tool_results(cx);
1444                if !canceled {
1445                    self.send_to_model(model, RequestKind::Chat, cx);
1446                }
1447            }
1448        }
1449
1450        cx.emit(ThreadEvent::ToolFinished {
1451            tool_use_id,
1452            pending_tool_use,
1453        });
1454    }
1455
1456    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1457        // Insert a user message to contain the tool results.
1458        self.insert_user_message(
1459            // TODO: Sending up a user message without any content results in the model sending back
1460            // responses that also don't have any content. We currently don't handle this case well,
1461            // so for now we provide some text to keep the model on track.
1462            "Here are the tool results.",
1463            Vec::new(),
1464            None,
1465            cx,
1466        );
1467    }
1468
1469    /// Cancels the last pending completion, if there are any pending.
1470    ///
1471    /// Returns whether a completion was canceled.
1472    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1473        let canceled = if self.pending_completions.pop().is_some() {
1474            true
1475        } else {
1476            let mut canceled = false;
1477            for pending_tool_use in self.tool_use.cancel_pending() {
1478                canceled = true;
1479                self.tool_finished(
1480                    pending_tool_use.id.clone(),
1481                    Some(pending_tool_use),
1482                    true,
1483                    cx,
1484                );
1485            }
1486            canceled
1487        };
1488        self.finalize_pending_checkpoint(cx);
1489        canceled
1490    }
1491
1492    pub fn feedback(&self) -> Option<ThreadFeedback> {
1493        self.feedback
1494    }
1495
1496    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1497        self.message_feedback.get(&message_id).copied()
1498    }
1499
1500    pub fn report_message_feedback(
1501        &mut self,
1502        message_id: MessageId,
1503        feedback: ThreadFeedback,
1504        cx: &mut Context<Self>,
1505    ) -> Task<Result<()>> {
1506        if self.message_feedback.get(&message_id) == Some(&feedback) {
1507            return Task::ready(Ok(()));
1508        }
1509
1510        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1511        let serialized_thread = self.serialize(cx);
1512        let thread_id = self.id().clone();
1513        let client = self.project.read(cx).client();
1514
1515        let enabled_tool_names: Vec<String> = self
1516            .tools()
1517            .enabled_tools(cx)
1518            .iter()
1519            .map(|tool| tool.name().to_string())
1520            .collect();
1521
1522        self.message_feedback.insert(message_id, feedback);
1523
1524        cx.notify();
1525
1526        let message_content = self
1527            .message(message_id)
1528            .map(|msg| msg.to_string())
1529            .unwrap_or_default();
1530
1531        cx.background_spawn(async move {
1532            let final_project_snapshot = final_project_snapshot.await;
1533            let serialized_thread = serialized_thread.await?;
1534            let thread_data =
1535                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1536
1537            let rating = match feedback {
1538                ThreadFeedback::Positive => "positive",
1539                ThreadFeedback::Negative => "negative",
1540            };
1541            telemetry::event!(
1542                "Assistant Thread Rated",
1543                rating,
1544                thread_id,
1545                enabled_tool_names,
1546                message_id = message_id.0,
1547                message_content,
1548                thread_data,
1549                final_project_snapshot
1550            );
1551            client.telemetry().flush_events();
1552
1553            Ok(())
1554        })
1555    }
1556
1557    pub fn report_feedback(
1558        &mut self,
1559        feedback: ThreadFeedback,
1560        cx: &mut Context<Self>,
1561    ) -> Task<Result<()>> {
1562        let last_assistant_message_id = self
1563            .messages
1564            .iter()
1565            .rev()
1566            .find(|msg| msg.role == Role::Assistant)
1567            .map(|msg| msg.id);
1568
1569        if let Some(message_id) = last_assistant_message_id {
1570            self.report_message_feedback(message_id, feedback, cx)
1571        } else {
1572            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1573            let serialized_thread = self.serialize(cx);
1574            let thread_id = self.id().clone();
1575            let client = self.project.read(cx).client();
1576            self.feedback = Some(feedback);
1577            cx.notify();
1578
1579            cx.background_spawn(async move {
1580                let final_project_snapshot = final_project_snapshot.await;
1581                let serialized_thread = serialized_thread.await?;
1582                let thread_data = serde_json::to_value(serialized_thread)
1583                    .unwrap_or_else(|_| serde_json::Value::Null);
1584
1585                let rating = match feedback {
1586                    ThreadFeedback::Positive => "positive",
1587                    ThreadFeedback::Negative => "negative",
1588                };
1589                telemetry::event!(
1590                    "Assistant Thread Rated",
1591                    rating,
1592                    thread_id,
1593                    thread_data,
1594                    final_project_snapshot
1595                );
1596                client.telemetry().flush_events();
1597
1598                Ok(())
1599            })
1600        }
1601    }
1602
1603    /// Create a snapshot of the current project state including git information and unsaved buffers.
1604    fn project_snapshot(
1605        project: Entity<Project>,
1606        cx: &mut Context<Self>,
1607    ) -> Task<Arc<ProjectSnapshot>> {
1608        let git_store = project.read(cx).git_store().clone();
1609        let worktree_snapshots: Vec<_> = project
1610            .read(cx)
1611            .visible_worktrees(cx)
1612            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1613            .collect();
1614
1615        cx.spawn(async move |_, cx| {
1616            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1617
1618            let mut unsaved_buffers = Vec::new();
1619            cx.update(|app_cx| {
1620                let buffer_store = project.read(app_cx).buffer_store();
1621                for buffer_handle in buffer_store.read(app_cx).buffers() {
1622                    let buffer = buffer_handle.read(app_cx);
1623                    if buffer.is_dirty() {
1624                        if let Some(file) = buffer.file() {
1625                            let path = file.path().to_string_lossy().to_string();
1626                            unsaved_buffers.push(path);
1627                        }
1628                    }
1629                }
1630            })
1631            .ok();
1632
1633            Arc::new(ProjectSnapshot {
1634                worktree_snapshots,
1635                unsaved_buffer_paths: unsaved_buffers,
1636                timestamp: Utc::now(),
1637            })
1638        })
1639    }
1640
1641    fn worktree_snapshot(
1642        worktree: Entity<project::Worktree>,
1643        git_store: Entity<GitStore>,
1644        cx: &App,
1645    ) -> Task<WorktreeSnapshot> {
1646        cx.spawn(async move |cx| {
1647            // Get worktree path and snapshot
1648            let worktree_info = cx.update(|app_cx| {
1649                let worktree = worktree.read(app_cx);
1650                let path = worktree.abs_path().to_string_lossy().to_string();
1651                let snapshot = worktree.snapshot();
1652                (path, snapshot)
1653            });
1654
1655            let Ok((worktree_path, _snapshot)) = worktree_info else {
1656                return WorktreeSnapshot {
1657                    worktree_path: String::new(),
1658                    git_state: None,
1659                };
1660            };
1661
1662            let git_state = git_store
1663                .update(cx, |git_store, cx| {
1664                    git_store
1665                        .repositories()
1666                        .values()
1667                        .find(|repo| {
1668                            repo.read(cx)
1669                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1670                                .is_some()
1671                        })
1672                        .cloned()
1673                })
1674                .ok()
1675                .flatten()
1676                .map(|repo| {
1677                    repo.update(cx, |repo, _| {
1678                        let current_branch =
1679                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1680                        repo.send_job(None, |state, _| async move {
1681                            let RepositoryState::Local { backend, .. } = state else {
1682                                return GitState {
1683                                    remote_url: None,
1684                                    head_sha: None,
1685                                    current_branch,
1686                                    diff: None,
1687                                };
1688                            };
1689
1690                            let remote_url = backend.remote_url("origin");
1691                            let head_sha = backend.head_sha();
1692                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1693
1694                            GitState {
1695                                remote_url,
1696                                head_sha,
1697                                current_branch,
1698                                diff,
1699                            }
1700                        })
1701                    })
1702                });
1703
1704            let git_state = match git_state {
1705                Some(git_state) => match git_state.ok() {
1706                    Some(git_state) => git_state.await.ok(),
1707                    None => None,
1708                },
1709                None => None,
1710            };
1711
1712            WorktreeSnapshot {
1713                worktree_path,
1714                git_state,
1715            }
1716        })
1717    }
1718
1719    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1720        let mut markdown = Vec::new();
1721
1722        if let Some(summary) = self.summary() {
1723            writeln!(markdown, "# {summary}\n")?;
1724        };
1725
1726        for message in self.messages() {
1727            writeln!(
1728                markdown,
1729                "## {role}\n",
1730                role = match message.role {
1731                    Role::User => "User",
1732                    Role::Assistant => "Assistant",
1733                    Role::System => "System",
1734                }
1735            )?;
1736
1737            if !message.context.is_empty() {
1738                writeln!(markdown, "{}", message.context)?;
1739            }
1740
1741            for segment in &message.segments {
1742                match segment {
1743                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1744                    MessageSegment::Thinking(text) => {
1745                        writeln!(markdown, "<think>{}</think>\n", text)?
1746                    }
1747                }
1748            }
1749
1750            for tool_use in self.tool_uses_for_message(message.id, cx) {
1751                writeln!(
1752                    markdown,
1753                    "**Use Tool: {} ({})**",
1754                    tool_use.name, tool_use.id
1755                )?;
1756                writeln!(markdown, "```json")?;
1757                writeln!(
1758                    markdown,
1759                    "{}",
1760                    serde_json::to_string_pretty(&tool_use.input)?
1761                )?;
1762                writeln!(markdown, "```")?;
1763            }
1764
1765            for tool_result in self.tool_results_for_message(message.id) {
1766                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1767                if tool_result.is_error {
1768                    write!(markdown, " (Error)")?;
1769                }
1770
1771                writeln!(markdown, "**\n")?;
1772                writeln!(markdown, "{}", tool_result.content)?;
1773            }
1774        }
1775
1776        Ok(String::from_utf8_lossy(&markdown).to_string())
1777    }
1778
1779    pub fn keep_edits_in_range(
1780        &mut self,
1781        buffer: Entity<language::Buffer>,
1782        buffer_range: Range<language::Anchor>,
1783        cx: &mut Context<Self>,
1784    ) {
1785        self.action_log.update(cx, |action_log, cx| {
1786            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1787        });
1788    }
1789
1790    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1791        self.action_log
1792            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1793    }
1794
1795    pub fn reject_edits_in_range(
1796        &mut self,
1797        buffer: Entity<language::Buffer>,
1798        buffer_range: Range<language::Anchor>,
1799        cx: &mut Context<Self>,
1800    ) -> Task<Result<()>> {
1801        self.action_log.update(cx, |action_log, cx| {
1802            action_log.reject_edits_in_range(buffer, buffer_range, cx)
1803        })
1804    }
1805
1806    pub fn action_log(&self) -> &Entity<ActionLog> {
1807        &self.action_log
1808    }
1809
1810    pub fn project(&self) -> &Entity<Project> {
1811        &self.project
1812    }
1813
1814    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1815        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1816            return;
1817        }
1818
1819        let now = Instant::now();
1820        if let Some(last) = self.last_auto_capture_at {
1821            if now.duration_since(last).as_secs() < 10 {
1822                return;
1823            }
1824        }
1825
1826        self.last_auto_capture_at = Some(now);
1827
1828        let thread_id = self.id().clone();
1829        let github_login = self
1830            .project
1831            .read(cx)
1832            .user_store()
1833            .read(cx)
1834            .current_user()
1835            .map(|user| user.github_login.clone());
1836        let client = self.project.read(cx).client().clone();
1837        let serialize_task = self.serialize(cx);
1838
1839        cx.background_executor()
1840            .spawn(async move {
1841                if let Ok(serialized_thread) = serialize_task.await {
1842                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1843                        telemetry::event!(
1844                            "Agent Thread Auto-Captured",
1845                            thread_id = thread_id.to_string(),
1846                            thread_data = thread_data,
1847                            auto_capture_reason = "tracked_user",
1848                            github_login = github_login
1849                        );
1850
1851                        client.telemetry().flush_events();
1852                    }
1853                }
1854            })
1855            .detach();
1856    }
1857
1858    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1859        let model_registry = LanguageModelRegistry::read_global(cx);
1860        let Some(model) = model_registry.default_model() else {
1861            return TotalTokenUsage::default();
1862        };
1863
1864        let max = model.model.max_token_count();
1865
1866        if let Some(exceeded_error) = &self.exceeded_window_error {
1867            if model.model.id() == exceeded_error.model_id {
1868                return TotalTokenUsage {
1869                    total: exceeded_error.token_count,
1870                    max,
1871                    ratio: TokenUsageRatio::Exceeded,
1872                };
1873            }
1874        }
1875
1876        #[cfg(debug_assertions)]
1877        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1878            .unwrap_or("0.8".to_string())
1879            .parse()
1880            .unwrap();
1881        #[cfg(not(debug_assertions))]
1882        let warning_threshold: f32 = 0.8;
1883
1884        let total = self.cumulative_token_usage.total_tokens() as usize;
1885
1886        let ratio = if total >= max {
1887            TokenUsageRatio::Exceeded
1888        } else if total as f32 / max as f32 >= warning_threshold {
1889            TokenUsageRatio::Warning
1890        } else {
1891            TokenUsageRatio::Normal
1892        };
1893
1894        TotalTokenUsage { total, max, ratio }
1895    }
1896
1897    pub fn deny_tool_use(
1898        &mut self,
1899        tool_use_id: LanguageModelToolUseId,
1900        tool_name: Arc<str>,
1901        cx: &mut Context<Self>,
1902    ) {
1903        let err = Err(anyhow::anyhow!(
1904            "Permission to run tool action denied by user"
1905        ));
1906
1907        self.tool_use
1908            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1909        self.tool_finished(tool_use_id.clone(), None, true, cx);
1910    }
1911}
1912
1913#[derive(Debug, Clone, Error)]
1914pub enum ThreadError {
1915    #[error("Payment required")]
1916    PaymentRequired,
1917    #[error("Max monthly spend reached")]
1918    MaxMonthlySpendReached,
1919    #[error("Message {header}: {message}")]
1920    Message {
1921        header: SharedString,
1922        message: SharedString,
1923    },
1924}
1925
1926#[derive(Debug, Clone)]
1927pub enum ThreadEvent {
1928    ShowError(ThreadError),
1929    StreamedCompletion,
1930    StreamedAssistantText(MessageId, String),
1931    StreamedAssistantThinking(MessageId, String),
1932    Stopped(Result<StopReason, Arc<anyhow::Error>>),
1933    MessageAdded(MessageId),
1934    MessageEdited(MessageId),
1935    MessageDeleted(MessageId),
1936    SummaryGenerated,
1937    SummaryChanged,
1938    UsePendingTools {
1939        tool_uses: Vec<PendingToolUse>,
1940    },
1941    ToolFinished {
1942        #[allow(unused)]
1943        tool_use_id: LanguageModelToolUseId,
1944        /// The pending tool use that corresponds to this tool.
1945        pending_tool_use: Option<PendingToolUse>,
1946    },
1947    CheckpointChanged,
1948    ToolConfirmationNeeded,
1949}
1950
1951impl EventEmitter<ThreadEvent> for Thread {}
1952
1953struct PendingCompletion {
1954    id: usize,
1955    _task: Task<()>,
1956}
1957
// Tests for how user messages, attached context and stale-buffer notices are
// stored on `Thread` messages and rendered into completion requests.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    // A file attached as context is rendered into `Message::context`, and in the
    // completion request that context is prepended to the message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // The rendered path uses the platform's separator, so the expected text
        // differs between Windows and other platforms.
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request. The request starts with one message that is
        // not from this test (presumably the system prompt), so the user message
        // is at index 1.
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    // Context already attached to an earlier message is not repeated in later
    // messages, neither in `Message::context` nor in the completion request.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain the 3 user messages (indices 1..=3) plus the
        // leading message at index 0 (presumably the system prompt).
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    // With no context attached, `Message::context` stays empty and the request
    // contains only the plain message text.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request (index 0 is the leading non-user message)
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    // Editing a buffer after it was attached as context makes the next request
    // end with a user message listing the changed file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at offset 1 (inside the first line). The exact position
            // doesn't matter; the buffer only needs to change after it was added
            // as context.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    // Installs a test `SettingsStore` as a global and registers the settings
    // used by the crates these tests exercise.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project: inserts `files` into a `FakeFs` under
    // `/test` (made platform-appropriate by `path!`) and opens a project with
    // `/test` as its worktree root.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    // Opens a test workspace window for `project`, loads a `ThreadStore` with a
    // default `PromptBuilder`, and creates a new thread plus a new `ContextStore`.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    Arc::default(),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await;

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    // Opens the buffer at `path` (a project path such as "test/code.rs"), adds it
    // to `context_store`, and returns the buffer. Panics if the path can't be
    // resolved or opened; returns the error if adding it to the context fails.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}