thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  19    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  20    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  21    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
  22    Role, StopReason, TokenUsage,
  23};
  24use project::Project;
  25use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  26use prompt_store::PromptBuilder;
  27use schemars::JsonSchema;
  28use serde::{Deserialize, Serialize};
  29use settings::Settings;
  30use thiserror::Error;
  31use util::{ResultExt as _, TryFutureExt as _, post_inc};
  32use uuid::Uuid;
  33
  34use crate::context::{AssistantContext, ContextId, format_context_as_string};
  35use crate::thread_store::{
  36    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  37    SerializedToolUse, SharedProjectContext,
  38};
  39use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  40
/// Selects the kind of request built by `Thread::to_completion_request` and sent by
/// `Thread::send_to_model`.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// NOTE(review): presumably a regular conversation turn — confirm against
    /// `to_completion_request`.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
  47
/// Identifier of a [`Thread`], stored as a reference-counted string.
///
/// New identifiers are produced by `ThreadId::new`; existing ones can be wrapped
/// through the `From<&str>` conversion.
#[derive(
    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
)]
pub struct ThreadId(Arc<str>);
  52
  53impl ThreadId {
  54    pub fn new() -> Self {
  55        Self(Uuid::new_v4().to_string().into())
  56    }
  57}
  58
  59impl std::fmt::Display for ThreadId {
  60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  61        write!(f, "{}", self.0)
  62    }
  63}
  64
  65impl From<&str> for ThreadId {
  66    fn from(value: &str) -> Self {
  67        Self(value.into())
  68    }
  69}
  70
/// Identifier of a [`Message`] within a thread.
///
/// `Thread` hands these out from a counter (`next_message_id`) that starts at 0 for a
/// new thread.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);
  73
  74impl MessageId {
  75    fn post_inc(&mut self) -> Self {
  76        Self(post_inc(&mut self.0))
  77    }
  78}
  79
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier of this message within its thread.
    pub id: MessageId,
    /// Role of the message author (user, assistant or system).
    pub role: Role,
    /// Ordered content segments (plain text and thinking text).
    pub segments: Vec<MessageSegment>,
    /// Context text attached to this message; empty when there is none. It is placed
    /// before the segments by `Message::to_string`.
    pub context: String,
}
  88
  89impl Message {
  90    /// Returns whether the message contains any meaningful text that should be displayed
  91    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  92    pub fn should_display_content(&self) -> bool {
  93        self.segments.iter().all(|segment| segment.should_display())
  94    }
  95
  96    pub fn push_thinking(&mut self, text: &str) {
  97        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  98            segment.push_str(text);
  99        } else {
 100            self.segments
 101                .push(MessageSegment::Thinking(text.to_string()));
 102        }
 103    }
 104
 105    pub fn push_text(&mut self, text: &str) {
 106        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 107            segment.push_str(text);
 108        } else {
 109            self.segments.push(MessageSegment::Text(text.to_string()));
 110        }
 111    }
 112
 113    pub fn to_string(&self) -> String {
 114        let mut result = String::new();
 115
 116        if !self.context.is_empty() {
 117            result.push_str(&self.context);
 118        }
 119
 120        for segment in &self.segments {
 121            match segment {
 122                MessageSegment::Text(text) => result.push_str(text),
 123                MessageSegment::Thinking(text) => {
 124                    result.push_str("<think>");
 125                    result.push_str(text);
 126                    result.push_str("</think>");
 127                }
 128            }
 129        }
 130
 131        result
 132    }
 133}
 134
/// One piece of a [`Message`]'s content.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MessageSegment {
    /// Plain text.
    Text(String),
    /// Thinking text; rendered wrapped in `<think>…</think>` by `Message::to_string`
    /// and `Thread::text`.
    Thinking(String),
}
 140
 141impl MessageSegment {
 142    pub fn text_mut(&mut self) -> &mut String {
 143        match self {
 144            Self::Text(text) => text,
 145            Self::Thinking(text) => text,
 146        }
 147    }
 148
 149    pub fn should_display(&self) -> bool {
 150        // We add USING_TOOL_MARKER when making a request that includes tool uses
 151        // without non-whitespace text around them, and this can cause the model
 152        // to mimic the pattern, so we consider those segments not displayable.
 153        match self {
 154            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 155            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 156        }
 157    }
 158}
 159
/// Serializable snapshot of a project, stored on a thread as its initial project
/// snapshot.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// NOTE(review): presumably paths of buffers with unsaved changes at snapshot time —
    /// confirm against `Thread::project_snapshot`.
    pub unsaved_buffer_paths: Vec<String>,
    /// UTC timestamp associated with the snapshot.
    pub timestamp: DateTime<Utc>,
}
 166
/// Per-worktree part of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree, as a string.
    pub worktree_path: String,
    /// Git information for the worktree, when available.
    pub git_state: Option<GitState>,
}
 172
/// Git information captured for a worktree; every field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the git remote, if known.
    pub remote_url: Option<String>,
    /// SHA of `HEAD`, if known.
    pub head_sha: Option<String>,
    /// Name of the current branch, if known.
    pub current_branch: Option<String>,
    /// Textual diff, if one was captured.
    pub diff: Option<String>,
}
 180
/// Associates a git store checkpoint with the message it was recorded for.
///
/// `Thread::insert_user_message` creates one for the newly inserted user message when
/// it is given a git checkpoint.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    // Message the checkpoint belongs to; `Thread::restore_checkpoint` truncates the
    // thread at this message after a successful restore.
    message_id: MessageId,
    // Checkpoint handle passed to the git store for restore/compare/delete.
    git_checkpoint: GitStoreCheckpoint,
}
 186
/// Feedback value that `Thread` stores for the whole thread and per message.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 192
/// Status of the most recent checkpoint restore started by `Thread::restore_checkpoint`.
pub enum LastRestoreCheckpoint {
    /// A restore for `message_id`'s checkpoint has been started and has not finished.
    Pending {
        message_id: MessageId,
    },
    /// The restore for `message_id`'s checkpoint failed; `error` is the error's string
    /// form.
    Error {
        message_id: MessageId,
        error: String,
    },
}
 202
 203impl LastRestoreCheckpoint {
 204    pub fn message_id(&self) -> MessageId {
 205        match self {
 206            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 207            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 208        }
 209    }
 210}
 211
/// State of a thread's detailed summary. Defaults to `NotGenerated`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary exists.
    #[default]
    NotGenerated,
    /// NOTE(review): presumably a summary is being generated up to `message_id` —
    /// confirm against the code that sets this state.
    Generating {
        message_id: MessageId,
    },
    /// A detailed summary is available as `text`.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 224
/// Token usage totals together with their classification.
///
/// NOTE(review): units and how `ratio` is derived from `total`/`max` are not visible
/// here — confirm where this is constructed.
#[derive(Default)]
pub struct TotalTokenUsage {
    pub total: usize,
    pub max: usize,
    pub ratio: TokenUsageRatio,
}
 231
/// Classification carried by [`TotalTokenUsage`]. Defaults to `Normal`.
#[derive(Debug, Default, PartialEq, Eq)]
pub enum TokenUsageRatio {
    #[default]
    Normal,
    Warning,
    Exceeded,
}
 239
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    // Refreshed by `touch_updated_at` whenever messages are inserted, edited or deleted.
    updated_at: DateTime<Utc>,
    // `None` for a new thread; `set_summary` is a no-op until this is `Some`.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    messages: Vec<Message>,
    // Id handed to the next message inserted through `insert_message`.
    next_message_id: MessageId,
    // Every context attached to the thread so far, keyed by context id.
    context: BTreeMap<ContextId, AssistantContext>,
    // For each user message, the ids of the contexts first attached with it.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    // Finalized checkpoints, keyed by the message they were recorded for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    // Checkpoint recorded by `insert_user_message`, waiting for
    // `finalize_pending_checkpoint` to keep or discard it.
    pending_checkpoint: Option<ThreadCheckpoint>,
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
}
 269
/// Details kept on a thread about a message that exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 277
 278impl Thread {
 279    pub fn new(
 280        project: Entity<Project>,
 281        tools: Entity<ToolWorkingSet>,
 282        prompt_builder: Arc<PromptBuilder>,
 283        system_prompt: SharedProjectContext,
 284        cx: &mut Context<Self>,
 285    ) -> Self {
 286        Self {
 287            id: ThreadId::new(),
 288            updated_at: Utc::now(),
 289            summary: None,
 290            pending_summary: Task::ready(None),
 291            detailed_summary_state: DetailedSummaryState::NotGenerated,
 292            messages: Vec::new(),
 293            next_message_id: MessageId(0),
 294            context: BTreeMap::default(),
 295            context_by_message: HashMap::default(),
 296            project_context: system_prompt,
 297            checkpoints_by_message: HashMap::default(),
 298            completion_count: 0,
 299            pending_completions: Vec::new(),
 300            project: project.clone(),
 301            prompt_builder,
 302            tools: tools.clone(),
 303            last_restore_checkpoint: None,
 304            pending_checkpoint: None,
 305            tool_use: ToolUseState::new(tools.clone()),
 306            action_log: cx.new(|_| ActionLog::new(project.clone())),
 307            initial_project_snapshot: {
 308                let project_snapshot = Self::project_snapshot(project, cx);
 309                cx.foreground_executor()
 310                    .spawn(async move { Some(project_snapshot.await) })
 311                    .shared()
 312            },
 313            cumulative_token_usage: TokenUsage::default(),
 314            exceeded_window_error: None,
 315            feedback: None,
 316            message_feedback: HashMap::default(),
 317            last_auto_capture_at: None,
 318        }
 319    }
 320
 321    pub fn deserialize(
 322        id: ThreadId,
 323        serialized: SerializedThread,
 324        project: Entity<Project>,
 325        tools: Entity<ToolWorkingSet>,
 326        prompt_builder: Arc<PromptBuilder>,
 327        project_context: SharedProjectContext,
 328        cx: &mut Context<Self>,
 329    ) -> Self {
 330        let next_message_id = MessageId(
 331            serialized
 332                .messages
 333                .last()
 334                .map(|message| message.id.0 + 1)
 335                .unwrap_or(0),
 336        );
 337        let tool_use =
 338            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 339
 340        Self {
 341            id,
 342            updated_at: serialized.updated_at,
 343            summary: Some(serialized.summary),
 344            pending_summary: Task::ready(None),
 345            detailed_summary_state: serialized.detailed_summary_state,
 346            messages: serialized
 347                .messages
 348                .into_iter()
 349                .map(|message| Message {
 350                    id: message.id,
 351                    role: message.role,
 352                    segments: message
 353                        .segments
 354                        .into_iter()
 355                        .map(|segment| match segment {
 356                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 357                            SerializedMessageSegment::Thinking { text } => {
 358                                MessageSegment::Thinking(text)
 359                            }
 360                        })
 361                        .collect(),
 362                    context: message.context,
 363                })
 364                .collect(),
 365            next_message_id,
 366            context: BTreeMap::default(),
 367            context_by_message: HashMap::default(),
 368            project_context,
 369            checkpoints_by_message: HashMap::default(),
 370            completion_count: 0,
 371            pending_completions: Vec::new(),
 372            last_restore_checkpoint: None,
 373            pending_checkpoint: None,
 374            project: project.clone(),
 375            prompt_builder,
 376            tools,
 377            tool_use,
 378            action_log: cx.new(|_| ActionLog::new(project)),
 379            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 380            cumulative_token_usage: serialized.cumulative_token_usage,
 381            exceeded_window_error: None,
 382            feedback: None,
 383            message_feedback: HashMap::default(),
 384            last_auto_capture_at: None,
 385        }
 386    }
 387
    /// Returns this thread's identifier.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 391
    /// Returns `true` when the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 395
    /// Returns the stored last-update timestamp.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 399
    /// Sets the last-update timestamp to the current UTC time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 403
    /// Returns a clone of the summary, or `None` when no summary has been set yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
 407
    /// Returns a clone of the shared project context handle.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 411
    /// Summary used when none has been generated, and substituted by `set_summary` for
    /// an empty summary.
    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 413
 414    pub fn summary_or_default(&self) -> SharedString {
 415        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 416    }
 417
 418    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 419        let Some(current_summary) = &self.summary else {
 420            // Don't allow setting summary until generated
 421            return;
 422        };
 423
 424        let mut new_summary = new_summary.into();
 425
 426        if new_summary.is_empty() {
 427            new_summary = Self::DEFAULT_SUMMARY;
 428        }
 429
 430        if current_summary != &new_summary {
 431            self.summary = Some(new_summary);
 432            cx.emit(ThreadEvent::SummaryChanged);
 433        }
 434    }
 435
 436    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 437        self.latest_detailed_summary()
 438            .unwrap_or_else(|| self.text().into())
 439    }
 440
 441    fn latest_detailed_summary(&self) -> Option<SharedString> {
 442        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 443            Some(text.clone())
 444        } else {
 445            None
 446        }
 447    }
 448
    /// Returns the first message with the given id, if any (linear search).
    pub fn message(&self, id: MessageId) -> Option<&Message> {
        self.messages.iter().find(|message| message.id == id)
    }
 452
    /// Iterates over the thread's messages in stored order.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }
 456
 457    pub fn is_generating(&self) -> bool {
 458        !self.pending_completions.is_empty() || !self.all_tools_finished()
 459    }
 460
    /// Returns the tool working set entity used by this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
 464
 465    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 466        self.tool_use
 467            .pending_tool_uses()
 468            .into_iter()
 469            .find(|tool_use| &tool_use.id == id)
 470    }
 471
 472    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 473        self.tool_use
 474            .pending_tool_uses()
 475            .into_iter()
 476            .filter(|tool_use| tool_use.status.needs_confirmation())
 477    }
 478
    /// Returns `true` when at least one tool use is pending, whatever its status.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
 482
    /// Returns a clone of the checkpoint stored for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 486
 487    pub fn restore_checkpoint(
 488        &mut self,
 489        checkpoint: ThreadCheckpoint,
 490        cx: &mut Context<Self>,
 491    ) -> Task<Result<()>> {
 492        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 493            message_id: checkpoint.message_id,
 494        });
 495        cx.emit(ThreadEvent::CheckpointChanged);
 496        cx.notify();
 497
 498        let git_store = self.project().read(cx).git_store().clone();
 499        let restore = git_store.update(cx, |git_store, cx| {
 500            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 501        });
 502
 503        cx.spawn(async move |this, cx| {
 504            let result = restore.await;
 505            this.update(cx, |this, cx| {
 506                if let Err(err) = result.as_ref() {
 507                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 508                        message_id: checkpoint.message_id,
 509                        error: err.to_string(),
 510                    });
 511                } else {
 512                    this.truncate(checkpoint.message_id, cx);
 513                    this.last_restore_checkpoint = None;
 514                }
 515                this.pending_checkpoint = None;
 516                cx.emit(ThreadEvent::CheckpointChanged);
 517                cx.notify();
 518            })?;
 519            result
 520        })
 521    }
 522
 523    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 524        let pending_checkpoint = if self.is_generating() {
 525            return;
 526        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 527            checkpoint
 528        } else {
 529            return;
 530        };
 531
 532        let git_store = self.project.read(cx).git_store().clone();
 533        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 534        cx.spawn(async move |this, cx| match final_checkpoint.await {
 535            Ok(final_checkpoint) => {
 536                let equal = git_store
 537                    .update(cx, |store, cx| {
 538                        store.compare_checkpoints(
 539                            pending_checkpoint.git_checkpoint.clone(),
 540                            final_checkpoint.clone(),
 541                            cx,
 542                        )
 543                    })?
 544                    .await
 545                    .unwrap_or(false);
 546
 547                if equal {
 548                    git_store
 549                        .update(cx, |store, cx| {
 550                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 551                        })?
 552                        .detach();
 553                } else {
 554                    this.update(cx, |this, cx| {
 555                        this.insert_checkpoint(pending_checkpoint, cx)
 556                    })?;
 557                }
 558
 559                git_store
 560                    .update(cx, |store, cx| {
 561                        store.delete_checkpoint(final_checkpoint, cx)
 562                    })?
 563                    .detach();
 564
 565                Ok(())
 566            }
 567            Err(_) => this.update(cx, |this, cx| {
 568                this.insert_checkpoint(pending_checkpoint, cx)
 569            }),
 570        })
 571        .detach();
 572    }
 573
 574    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 575        self.checkpoints_by_message
 576            .insert(checkpoint.message_id, checkpoint);
 577        cx.emit(ThreadEvent::CheckpointChanged);
 578        cx.notify();
 579    }
 580
    /// Returns the status of the last checkpoint restore, if one is pending or failed.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 584
 585    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 586        let Some(message_ix) = self
 587            .messages
 588            .iter()
 589            .rposition(|message| message.id == message_id)
 590        else {
 591            return;
 592        };
 593        for deleted_message in self.messages.drain(message_ix..) {
 594            self.context_by_message.remove(&deleted_message.id);
 595            self.checkpoints_by_message.remove(&deleted_message.id);
 596        }
 597        cx.notify();
 598    }
 599
 600    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 601        self.context_by_message
 602            .get(&id)
 603            .into_iter()
 604            .flat_map(|context| {
 605                context
 606                    .iter()
 607                    .filter_map(|context_id| self.context.get(&context_id))
 608            })
 609    }
 610
 611    /// Returns whether all of the tool uses have finished running.
 612    pub fn all_tools_finished(&self) -> bool {
 613        // If the only pending tool uses left are the ones with errors, then
 614        // that means that we've finished running all of the pending tools.
 615        self.tool_use
 616            .pending_tool_uses()
 617            .iter()
 618            .all(|tool_use| tool_use.status.is_error())
 619    }
 620
    /// Returns the tool uses recorded for the given message (delegates to the thread's
    /// `ToolUseState`).
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 624
    /// Returns the tool results recorded for the given message (delegates to the
    /// thread's `ToolUseState`).
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }
 628
    /// Returns the result of the tool use with the given id, if one is recorded.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 632
    /// Returns whether tool results are recorded for the given message (delegates to
    /// the thread's `ToolUseState`).
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
 636
 637    pub fn insert_user_message(
 638        &mut self,
 639        text: impl Into<String>,
 640        context: Vec<AssistantContext>,
 641        git_checkpoint: Option<GitStoreCheckpoint>,
 642        cx: &mut Context<Self>,
 643    ) -> MessageId {
 644        let text = text.into();
 645
 646        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 647
 648        // Filter out contexts that have already been included in previous messages
 649        let new_context: Vec<_> = context
 650            .into_iter()
 651            .filter(|ctx| !self.context.contains_key(&ctx.id()))
 652            .collect();
 653
 654        if !new_context.is_empty() {
 655            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 656                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 657                    message.context = context_string;
 658                }
 659            }
 660
 661            self.action_log.update(cx, |log, cx| {
 662                // Track all buffers added as context
 663                for ctx in &new_context {
 664                    match ctx {
 665                        AssistantContext::File(file_ctx) => {
 666                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 667                        }
 668                        AssistantContext::Directory(dir_ctx) => {
 669                            for context_buffer in &dir_ctx.context_buffers {
 670                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 671                            }
 672                        }
 673                        AssistantContext::Symbol(symbol_ctx) => {
 674                            log.buffer_added_as_context(
 675                                symbol_ctx.context_symbol.buffer.clone(),
 676                                cx,
 677                            );
 678                        }
 679                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 680                    }
 681                }
 682            });
 683        }
 684
 685        let context_ids = new_context
 686            .iter()
 687            .map(|context| context.id())
 688            .collect::<Vec<_>>();
 689        self.context.extend(
 690            new_context
 691                .into_iter()
 692                .map(|context| (context.id(), context)),
 693        );
 694        self.context_by_message.insert(message_id, context_ids);
 695
 696        if let Some(git_checkpoint) = git_checkpoint {
 697            self.pending_checkpoint = Some(ThreadCheckpoint {
 698                message_id,
 699                git_checkpoint,
 700            });
 701        }
 702
 703        self.auto_capture_telemetry(cx);
 704
 705        message_id
 706    }
 707
 708    pub fn insert_message(
 709        &mut self,
 710        role: Role,
 711        segments: Vec<MessageSegment>,
 712        cx: &mut Context<Self>,
 713    ) -> MessageId {
 714        let id = self.next_message_id.post_inc();
 715        self.messages.push(Message {
 716            id,
 717            role,
 718            segments,
 719            context: String::new(),
 720        });
 721        self.touch_updated_at();
 722        cx.emit(ThreadEvent::MessageAdded(id));
 723        id
 724    }
 725
 726    pub fn edit_message(
 727        &mut self,
 728        id: MessageId,
 729        new_role: Role,
 730        new_segments: Vec<MessageSegment>,
 731        cx: &mut Context<Self>,
 732    ) -> bool {
 733        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 734            return false;
 735        };
 736        message.role = new_role;
 737        message.segments = new_segments;
 738        self.touch_updated_at();
 739        cx.emit(ThreadEvent::MessageEdited(id));
 740        true
 741    }
 742
 743    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 744        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 745            return false;
 746        };
 747        self.messages.remove(index);
 748        self.context_by_message.remove(&id);
 749        self.touch_updated_at();
 750        cx.emit(ThreadEvent::MessageDeleted(id));
 751        true
 752    }
 753
 754    /// Returns the representation of this [`Thread`] in a textual form.
 755    ///
 756    /// This is the representation we use when attaching a thread as context to another thread.
 757    pub fn text(&self) -> String {
 758        let mut text = String::new();
 759
 760        for message in &self.messages {
 761            text.push_str(match message.role {
 762                language_model::Role::User => "User:",
 763                language_model::Role::Assistant => "Assistant:",
 764                language_model::Role::System => "System:",
 765            });
 766            text.push('\n');
 767
 768            for segment in &message.segments {
 769                match segment {
 770                    MessageSegment::Text(content) => text.push_str(content),
 771                    MessageSegment::Thinking(content) => {
 772                        text.push_str(&format!("<think>{}</think>", content))
 773                    }
 774                }
 775            }
 776            text.push('\n');
 777        }
 778
 779        text
 780    }
 781
 782    /// Serializes this thread into a format for storage or telemetry.
 783    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 784        let initial_project_snapshot = self.initial_project_snapshot.clone();
 785        cx.spawn(async move |this, cx| {
 786            let initial_project_snapshot = initial_project_snapshot.await;
 787            this.read_with(cx, |this, cx| SerializedThread {
 788                version: SerializedThread::VERSION.to_string(),
 789                summary: this.summary_or_default(),
 790                updated_at: this.updated_at(),
 791                messages: this
 792                    .messages()
 793                    .map(|message| SerializedMessage {
 794                        id: message.id,
 795                        role: message.role,
 796                        segments: message
 797                            .segments
 798                            .iter()
 799                            .map(|segment| match segment {
 800                                MessageSegment::Text(text) => {
 801                                    SerializedMessageSegment::Text { text: text.clone() }
 802                                }
 803                                MessageSegment::Thinking(text) => {
 804                                    SerializedMessageSegment::Thinking { text: text.clone() }
 805                                }
 806                            })
 807                            .collect(),
 808                        tool_uses: this
 809                            .tool_uses_for_message(message.id, cx)
 810                            .into_iter()
 811                            .map(|tool_use| SerializedToolUse {
 812                                id: tool_use.id,
 813                                name: tool_use.name,
 814                                input: tool_use.input,
 815                            })
 816                            .collect(),
 817                        tool_results: this
 818                            .tool_results_for_message(message.id)
 819                            .into_iter()
 820                            .map(|tool_result| SerializedToolResult {
 821                                tool_use_id: tool_result.tool_use_id.clone(),
 822                                is_error: tool_result.is_error,
 823                                content: tool_result.content.clone(),
 824                            })
 825                            .collect(),
 826                        context: message.context.clone(),
 827                    })
 828                    .collect(),
 829                initial_project_snapshot,
 830                cumulative_token_usage: this.cumulative_token_usage,
 831                detailed_summary_state: this.detailed_summary_state.clone(),
 832                exceeded_window_error: this.exceeded_window_error.clone(),
 833            })
 834        })
 835    }
 836
 837    pub fn send_to_model(
 838        &mut self,
 839        model: Arc<dyn LanguageModel>,
 840        request_kind: RequestKind,
 841        cx: &mut Context<Self>,
 842    ) {
 843        let mut request = self.to_completion_request(request_kind, cx);
 844        if model.supports_tools() {
 845            request.tools = {
 846                let mut tools = Vec::new();
 847                tools.extend(
 848                    self.tools()
 849                        .read(cx)
 850                        .enabled_tools(cx)
 851                        .into_iter()
 852                        .filter_map(|tool| {
 853                            // Skip tools that cannot be supported
 854                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 855                            Some(LanguageModelRequestTool {
 856                                name: tool.name(),
 857                                description: tool.description(),
 858                                input_schema,
 859                            })
 860                        }),
 861                );
 862
 863                tools
 864            };
 865        }
 866
 867        self.stream_completion(request, model, cx);
 868    }
 869
 870    pub fn used_tools_since_last_user_message(&self) -> bool {
 871        for message in self.messages.iter().rev() {
 872            if self.tool_use.message_has_tool_results(message.id) {
 873                return true;
 874            } else if message.role == Role::User {
 875                return false;
 876            }
 877        }
 878
 879        false
 880    }
 881
 882    pub fn to_completion_request(
 883        &self,
 884        request_kind: RequestKind,
 885        cx: &App,
 886    ) -> LanguageModelRequest {
 887        let mut request = LanguageModelRequest {
 888            messages: vec![],
 889            tools: Vec::new(),
 890            stop: Vec::new(),
 891            temperature: None,
 892        };
 893
 894        if let Some(project_context) = self.project_context.borrow().as_ref() {
 895            if let Some(system_prompt) = self
 896                .prompt_builder
 897                .generate_assistant_system_prompt(project_context)
 898                .context("failed to generate assistant system prompt")
 899                .log_err()
 900            {
 901                request.messages.push(LanguageModelRequestMessage {
 902                    role: Role::System,
 903                    content: vec![MessageContent::Text(system_prompt)],
 904                    cache: true,
 905                });
 906            }
 907        } else {
 908            log::error!("project_context not set.")
 909        }
 910
 911        for message in &self.messages {
 912            let mut request_message = LanguageModelRequestMessage {
 913                role: message.role,
 914                content: Vec::new(),
 915                cache: false,
 916            };
 917
 918            match request_kind {
 919                RequestKind::Chat => {
 920                    self.tool_use
 921                        .attach_tool_results(message.id, &mut request_message);
 922                }
 923                RequestKind::Summarize => {
 924                    // We don't care about tool use during summarization.
 925                    if self.tool_use.message_has_tool_results(message.id) {
 926                        continue;
 927                    }
 928                }
 929            }
 930
 931            if !message.segments.is_empty() {
 932                request_message
 933                    .content
 934                    .push(MessageContent::Text(message.to_string()));
 935            }
 936
 937            match request_kind {
 938                RequestKind::Chat => {
 939                    self.tool_use
 940                        .attach_tool_uses(message.id, &mut request_message);
 941                }
 942                RequestKind::Summarize => {
 943                    // We don't care about tool use during summarization.
 944                }
 945            };
 946
 947            request.messages.push(request_message);
 948        }
 949
 950        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 951        if let Some(last) = request.messages.last_mut() {
 952            last.cache = true;
 953        }
 954
 955        self.attached_tracked_files_state(&mut request.messages, cx);
 956
 957        request
 958    }
 959
 960    fn attached_tracked_files_state(
 961        &self,
 962        messages: &mut Vec<LanguageModelRequestMessage>,
 963        cx: &App,
 964    ) {
 965        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 966
 967        let mut stale_message = String::new();
 968
 969        let action_log = self.action_log.read(cx);
 970
 971        for stale_file in action_log.stale_buffers(cx) {
 972            let Some(file) = stale_file.read(cx).file() else {
 973                continue;
 974            };
 975
 976            if stale_message.is_empty() {
 977                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
 978            }
 979
 980            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 981        }
 982
 983        let mut content = Vec::with_capacity(2);
 984
 985        if !stale_message.is_empty() {
 986            content.push(stale_message.into());
 987        }
 988
 989        if action_log.has_edited_files_since_project_diagnostics_check() {
 990            content.push(
 991                "\n\nWhen you're done making changes, make sure to check project diagnostics \
 992                and fix all errors AND warnings you introduced! \
 993                DO NOT mention you're going to do this until you're done."
 994                    .into(),
 995            );
 996        }
 997
 998        if !content.is_empty() {
 999            let context_message = LanguageModelRequestMessage {
1000                role: Role::User,
1001                content,
1002                cache: false,
1003            };
1004
1005            messages.push(context_message);
1006        }
1007    }
1008
1009    pub fn stream_completion(
1010        &mut self,
1011        request: LanguageModelRequest,
1012        model: Arc<dyn LanguageModel>,
1013        cx: &mut Context<Self>,
1014    ) {
1015        let pending_completion_id = post_inc(&mut self.completion_count);
1016
1017        let task = cx.spawn(async move |thread, cx| {
1018            let stream = model.stream_completion(request, &cx);
1019            let initial_token_usage =
1020                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1021            let stream_completion = async {
1022                let mut events = stream.await?;
1023                let mut stop_reason = StopReason::EndTurn;
1024                let mut current_token_usage = TokenUsage::default();
1025
1026                while let Some(event) = events.next().await {
1027                    let event = event?;
1028
1029                    thread.update(cx, |thread, cx| {
1030                        match event {
1031                            LanguageModelCompletionEvent::StartMessage { .. } => {
1032                                thread.insert_message(
1033                                    Role::Assistant,
1034                                    vec![MessageSegment::Text(String::new())],
1035                                    cx,
1036                                );
1037                            }
1038                            LanguageModelCompletionEvent::Stop(reason) => {
1039                                stop_reason = reason;
1040                            }
1041                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1042                                thread.cumulative_token_usage = thread.cumulative_token_usage
1043                                    + token_usage
1044                                    - current_token_usage;
1045                                current_token_usage = token_usage;
1046                            }
1047                            LanguageModelCompletionEvent::Text(chunk) => {
1048                                if let Some(last_message) = thread.messages.last_mut() {
1049                                    if last_message.role == Role::Assistant {
1050                                        last_message.push_text(&chunk);
1051                                        cx.emit(ThreadEvent::StreamedAssistantText(
1052                                            last_message.id,
1053                                            chunk,
1054                                        ));
1055                                    } else {
1056                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1057                                        // of a new Assistant response.
1058                                        //
1059                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1060                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1061                                        thread.insert_message(
1062                                            Role::Assistant,
1063                                            vec![MessageSegment::Text(chunk.to_string())],
1064                                            cx,
1065                                        );
1066                                    };
1067                                }
1068                            }
1069                            LanguageModelCompletionEvent::Thinking(chunk) => {
1070                                if let Some(last_message) = thread.messages.last_mut() {
1071                                    if last_message.role == Role::Assistant {
1072                                        last_message.push_thinking(&chunk);
1073                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1074                                            last_message.id,
1075                                            chunk,
1076                                        ));
1077                                    } else {
1078                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1079                                        // of a new Assistant response.
1080                                        //
1081                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1082                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1083                                        thread.insert_message(
1084                                            Role::Assistant,
1085                                            vec![MessageSegment::Thinking(chunk.to_string())],
1086                                            cx,
1087                                        );
1088                                    };
1089                                }
1090                            }
1091                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1092                                let last_assistant_message_id = thread
1093                                    .messages
1094                                    .iter_mut()
1095                                    .rfind(|message| message.role == Role::Assistant)
1096                                    .map(|message| message.id)
1097                                    .unwrap_or_else(|| {
1098                                        thread.insert_message(Role::Assistant, vec![], cx)
1099                                    });
1100
1101                                thread.tool_use.request_tool_use(
1102                                    last_assistant_message_id,
1103                                    tool_use,
1104                                    cx,
1105                                );
1106                            }
1107                        }
1108
1109                        thread.touch_updated_at();
1110                        cx.emit(ThreadEvent::StreamedCompletion);
1111                        cx.notify();
1112
1113                        thread.auto_capture_telemetry(cx);
1114                    })?;
1115
1116                    smol::future::yield_now().await;
1117                }
1118
1119                thread.update(cx, |thread, cx| {
1120                    thread
1121                        .pending_completions
1122                        .retain(|completion| completion.id != pending_completion_id);
1123
1124                    if thread.summary.is_none() && thread.messages.len() >= 2 {
1125                        thread.summarize(cx);
1126                    }
1127                })?;
1128
1129                anyhow::Ok(stop_reason)
1130            };
1131
1132            let result = stream_completion.await;
1133
1134            thread
1135                .update(cx, |thread, cx| {
1136                    thread.finalize_pending_checkpoint(cx);
1137                    match result.as_ref() {
1138                        Ok(stop_reason) => match stop_reason {
1139                            StopReason::ToolUse => {
1140                                let tool_uses = thread.use_pending_tools(cx);
1141                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1142                            }
1143                            StopReason::EndTurn => {}
1144                            StopReason::MaxTokens => {}
1145                        },
1146                        Err(error) => {
1147                            if error.is::<PaymentRequiredError>() {
1148                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1149                            } else if error.is::<MaxMonthlySpendReachedError>() {
1150                                cx.emit(ThreadEvent::ShowError(
1151                                    ThreadError::MaxMonthlySpendReached,
1152                                ));
1153                            } else if let Some(known_error) =
1154                                error.downcast_ref::<LanguageModelKnownError>()
1155                            {
1156                                match known_error {
1157                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1158                                        tokens,
1159                                    } => {
1160                                        thread.exceeded_window_error = Some(ExceededWindowError {
1161                                            model_id: model.id(),
1162                                            token_count: *tokens,
1163                                        });
1164                                        cx.notify();
1165                                    }
1166                                }
1167                            } else {
1168                                let error_message = error
1169                                    .chain()
1170                                    .map(|err| err.to_string())
1171                                    .collect::<Vec<_>>()
1172                                    .join("\n");
1173                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1174                                    header: "Error interacting with language model".into(),
1175                                    message: SharedString::from(error_message.clone()),
1176                                }));
1177                            }
1178
1179                            thread.cancel_last_completion(cx);
1180                        }
1181                    }
1182                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1183
1184                    thread.auto_capture_telemetry(cx);
1185
1186                    if let Ok(initial_usage) = initial_token_usage {
1187                        let usage = thread.cumulative_token_usage - initial_usage;
1188
1189                        telemetry::event!(
1190                            "Assistant Thread Completion",
1191                            thread_id = thread.id().to_string(),
1192                            model = model.telemetry_id(),
1193                            model_provider = model.provider_id().to_string(),
1194                            input_tokens = usage.input_tokens,
1195                            output_tokens = usage.output_tokens,
1196                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1197                            cache_read_input_tokens = usage.cache_read_input_tokens,
1198                        );
1199                    }
1200                })
1201                .ok();
1202        });
1203
1204        self.pending_completions.push(PendingCompletion {
1205            id: pending_completion_id,
1206            _task: task,
1207        });
1208    }
1209
1210    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1211        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1212            return;
1213        };
1214
1215        if !model.provider.is_authenticated(cx) {
1216            return;
1217        }
1218
1219        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1220        request.messages.push(LanguageModelRequestMessage {
1221            role: Role::User,
1222            content: vec![
1223                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1224                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1225                 If the conversation is about a specific subject, include it in the title. \
1226                 Be descriptive. DO NOT speak in the first person."
1227                    .into(),
1228            ],
1229            cache: false,
1230        });
1231
1232        self.pending_summary = cx.spawn(async move |this, cx| {
1233            async move {
1234                let stream = model.model.stream_completion_text(request, &cx);
1235                let mut messages = stream.await?;
1236
1237                let mut new_summary = String::new();
1238                while let Some(message) = messages.stream.next().await {
1239                    let text = message?;
1240                    let mut lines = text.lines();
1241                    new_summary.extend(lines.next());
1242
1243                    // Stop if the LLM generated multiple lines.
1244                    if lines.next().is_some() {
1245                        break;
1246                    }
1247                }
1248
1249                this.update(cx, |this, cx| {
1250                    if !new_summary.is_empty() {
1251                        this.summary = Some(new_summary.into());
1252                    }
1253
1254                    cx.emit(ThreadEvent::SummaryGenerated);
1255                })?;
1256
1257                anyhow::Ok(())
1258            }
1259            .log_err()
1260            .await
1261        });
1262    }
1263
1264    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1265        let last_message_id = self.messages.last().map(|message| message.id)?;
1266
1267        match &self.detailed_summary_state {
1268            DetailedSummaryState::Generating { message_id, .. }
1269            | DetailedSummaryState::Generated { message_id, .. }
1270                if *message_id == last_message_id =>
1271            {
1272                // Already up-to-date
1273                return None;
1274            }
1275            _ => {}
1276        }
1277
1278        let ConfiguredModel { model, provider } =
1279            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1280
1281        if !provider.is_authenticated(cx) {
1282            return None;
1283        }
1284
1285        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1286
1287        request.messages.push(LanguageModelRequestMessage {
1288            role: Role::User,
1289            content: vec![
1290                "Generate a detailed summary of this conversation. Include:\n\
1291                1. A brief overview of what was discussed\n\
1292                2. Key facts or information discovered\n\
1293                3. Outcomes or conclusions reached\n\
1294                4. Any action items or next steps if any\n\
1295                Format it in Markdown with headings and bullet points."
1296                    .into(),
1297            ],
1298            cache: false,
1299        });
1300
1301        let task = cx.spawn(async move |thread, cx| {
1302            let stream = model.stream_completion_text(request, &cx);
1303            let Some(mut messages) = stream.await.log_err() else {
1304                thread
1305                    .update(cx, |this, _cx| {
1306                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1307                    })
1308                    .log_err();
1309
1310                return;
1311            };
1312
1313            let mut new_detailed_summary = String::new();
1314
1315            while let Some(chunk) = messages.stream.next().await {
1316                if let Some(chunk) = chunk.log_err() {
1317                    new_detailed_summary.push_str(&chunk);
1318                }
1319            }
1320
1321            thread
1322                .update(cx, |this, _cx| {
1323                    this.detailed_summary_state = DetailedSummaryState::Generated {
1324                        text: new_detailed_summary.into(),
1325                        message_id: last_message_id,
1326                    };
1327                })
1328                .log_err();
1329        });
1330
1331        self.detailed_summary_state = DetailedSummaryState::Generating {
1332            message_id: last_message_id,
1333        };
1334
1335        Some(task)
1336    }
1337
1338    pub fn is_generating_detailed_summary(&self) -> bool {
1339        matches!(
1340            self.detailed_summary_state,
1341            DetailedSummaryState::Generating { .. }
1342        )
1343    }
1344
1345    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1346        self.auto_capture_telemetry(cx);
1347        let request = self.to_completion_request(RequestKind::Chat, cx);
1348        let messages = Arc::new(request.messages);
1349        let pending_tool_uses = self
1350            .tool_use
1351            .pending_tool_uses()
1352            .into_iter()
1353            .filter(|tool_use| tool_use.status.is_idle())
1354            .cloned()
1355            .collect::<Vec<_>>();
1356
1357        for tool_use in pending_tool_uses.iter() {
1358            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1359                if tool.needs_confirmation(&tool_use.input, cx)
1360                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1361                {
1362                    self.tool_use.confirm_tool_use(
1363                        tool_use.id.clone(),
1364                        tool_use.ui_text.clone(),
1365                        tool_use.input.clone(),
1366                        messages.clone(),
1367                        tool,
1368                    );
1369                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1370                } else {
1371                    self.run_tool(
1372                        tool_use.id.clone(),
1373                        tool_use.ui_text.clone(),
1374                        tool_use.input.clone(),
1375                        &messages,
1376                        tool,
1377                        cx,
1378                    );
1379                }
1380            }
1381        }
1382
1383        pending_tool_uses
1384    }
1385
1386    pub fn run_tool(
1387        &mut self,
1388        tool_use_id: LanguageModelToolUseId,
1389        ui_text: impl Into<SharedString>,
1390        input: serde_json::Value,
1391        messages: &[LanguageModelRequestMessage],
1392        tool: Arc<dyn Tool>,
1393        cx: &mut Context<Thread>,
1394    ) {
1395        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1396        self.tool_use
1397            .run_pending_tool(tool_use_id, ui_text.into(), task);
1398    }
1399
1400    fn spawn_tool_use(
1401        &mut self,
1402        tool_use_id: LanguageModelToolUseId,
1403        messages: &[LanguageModelRequestMessage],
1404        input: serde_json::Value,
1405        tool: Arc<dyn Tool>,
1406        cx: &mut Context<Thread>,
1407    ) -> Task<()> {
1408        let tool_name: Arc<str> = tool.name().into();
1409
1410        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1411            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1412        } else {
1413            tool.run(
1414                input,
1415                messages,
1416                self.project.clone(),
1417                self.action_log.clone(),
1418                cx,
1419            )
1420        };
1421
1422        cx.spawn({
1423            async move |thread: WeakEntity<Thread>, cx| {
1424                let output = tool_result.output.await;
1425
1426                thread
1427                    .update(cx, |thread, cx| {
1428                        let pending_tool_use = thread.tool_use.insert_tool_output(
1429                            tool_use_id.clone(),
1430                            tool_name,
1431                            output,
1432                            cx,
1433                        );
1434                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1435                    })
1436                    .ok();
1437            }
1438        })
1439    }
1440
1441    fn tool_finished(
1442        &mut self,
1443        tool_use_id: LanguageModelToolUseId,
1444        pending_tool_use: Option<PendingToolUse>,
1445        canceled: bool,
1446        cx: &mut Context<Self>,
1447    ) {
1448        if self.all_tools_finished() {
1449            let model_registry = LanguageModelRegistry::read_global(cx);
1450            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1451                self.attach_tool_results(cx);
1452                if !canceled {
1453                    self.send_to_model(model, RequestKind::Chat, cx);
1454                }
1455            }
1456        }
1457
1458        cx.emit(ThreadEvent::ToolFinished {
1459            tool_use_id,
1460            pending_tool_use,
1461        });
1462    }
1463
1464    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1465        // Insert a user message to contain the tool results.
1466        self.insert_user_message(
1467            // TODO: Sending up a user message without any content results in the model sending back
1468            // responses that also don't have any content. We currently don't handle this case well,
1469            // so for now we provide some text to keep the model on track.
1470            "Here are the tool results.",
1471            Vec::new(),
1472            None,
1473            cx,
1474        );
1475    }
1476
1477    /// Cancels the last pending completion, if there are any pending.
1478    ///
1479    /// Returns whether a completion was canceled.
1480    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1481        let canceled = if self.pending_completions.pop().is_some() {
1482            true
1483        } else {
1484            let mut canceled = false;
1485            for pending_tool_use in self.tool_use.cancel_pending() {
1486                canceled = true;
1487                self.tool_finished(
1488                    pending_tool_use.id.clone(),
1489                    Some(pending_tool_use),
1490                    true,
1491                    cx,
1492                );
1493            }
1494            canceled
1495        };
1496        self.finalize_pending_checkpoint(cx);
1497        canceled
1498    }
1499
    /// Returns the value of the thread-level `feedback` field, if one is set.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1503
1504    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1505        self.message_feedback.get(&message_id).copied()
1506    }
1507
1508    pub fn report_message_feedback(
1509        &mut self,
1510        message_id: MessageId,
1511        feedback: ThreadFeedback,
1512        cx: &mut Context<Self>,
1513    ) -> Task<Result<()>> {
1514        if self.message_feedback.get(&message_id) == Some(&feedback) {
1515            return Task::ready(Ok(()));
1516        }
1517
1518        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1519        let serialized_thread = self.serialize(cx);
1520        let thread_id = self.id().clone();
1521        let client = self.project.read(cx).client();
1522
1523        let enabled_tool_names: Vec<String> = self
1524            .tools()
1525            .read(cx)
1526            .enabled_tools(cx)
1527            .iter()
1528            .map(|tool| tool.name().to_string())
1529            .collect();
1530
1531        self.message_feedback.insert(message_id, feedback);
1532
1533        cx.notify();
1534
1535        let message_content = self
1536            .message(message_id)
1537            .map(|msg| msg.to_string())
1538            .unwrap_or_default();
1539
1540        cx.background_spawn(async move {
1541            let final_project_snapshot = final_project_snapshot.await;
1542            let serialized_thread = serialized_thread.await?;
1543            let thread_data =
1544                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1545
1546            let rating = match feedback {
1547                ThreadFeedback::Positive => "positive",
1548                ThreadFeedback::Negative => "negative",
1549            };
1550            telemetry::event!(
1551                "Assistant Thread Rated",
1552                rating,
1553                thread_id,
1554                enabled_tool_names,
1555                message_id = message_id.0,
1556                message_content,
1557                thread_data,
1558                final_project_snapshot
1559            );
1560            client.telemetry().flush_events();
1561
1562            Ok(())
1563        })
1564    }
1565
1566    pub fn report_feedback(
1567        &mut self,
1568        feedback: ThreadFeedback,
1569        cx: &mut Context<Self>,
1570    ) -> Task<Result<()>> {
1571        let last_assistant_message_id = self
1572            .messages
1573            .iter()
1574            .rev()
1575            .find(|msg| msg.role == Role::Assistant)
1576            .map(|msg| msg.id);
1577
1578        if let Some(message_id) = last_assistant_message_id {
1579            self.report_message_feedback(message_id, feedback, cx)
1580        } else {
1581            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1582            let serialized_thread = self.serialize(cx);
1583            let thread_id = self.id().clone();
1584            let client = self.project.read(cx).client();
1585            self.feedback = Some(feedback);
1586            cx.notify();
1587
1588            cx.background_spawn(async move {
1589                let final_project_snapshot = final_project_snapshot.await;
1590                let serialized_thread = serialized_thread.await?;
1591                let thread_data = serde_json::to_value(serialized_thread)
1592                    .unwrap_or_else(|_| serde_json::Value::Null);
1593
1594                let rating = match feedback {
1595                    ThreadFeedback::Positive => "positive",
1596                    ThreadFeedback::Negative => "negative",
1597                };
1598                telemetry::event!(
1599                    "Assistant Thread Rated",
1600                    rating,
1601                    thread_id,
1602                    thread_data,
1603                    final_project_snapshot
1604                );
1605                client.telemetry().flush_events();
1606
1607                Ok(())
1608            })
1609        }
1610    }
1611
1612    /// Create a snapshot of the current project state including git information and unsaved buffers.
1613    fn project_snapshot(
1614        project: Entity<Project>,
1615        cx: &mut Context<Self>,
1616    ) -> Task<Arc<ProjectSnapshot>> {
1617        let git_store = project.read(cx).git_store().clone();
1618        let worktree_snapshots: Vec<_> = project
1619            .read(cx)
1620            .visible_worktrees(cx)
1621            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1622            .collect();
1623
1624        cx.spawn(async move |_, cx| {
1625            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1626
1627            let mut unsaved_buffers = Vec::new();
1628            cx.update(|app_cx| {
1629                let buffer_store = project.read(app_cx).buffer_store();
1630                for buffer_handle in buffer_store.read(app_cx).buffers() {
1631                    let buffer = buffer_handle.read(app_cx);
1632                    if buffer.is_dirty() {
1633                        if let Some(file) = buffer.file() {
1634                            let path = file.path().to_string_lossy().to_string();
1635                            unsaved_buffers.push(path);
1636                        }
1637                    }
1638                }
1639            })
1640            .ok();
1641
1642            Arc::new(ProjectSnapshot {
1643                worktree_snapshots,
1644                unsaved_buffer_paths: unsaved_buffers,
1645                timestamp: Utc::now(),
1646            })
1647        })
1648    }
1649
1650    fn worktree_snapshot(
1651        worktree: Entity<project::Worktree>,
1652        git_store: Entity<GitStore>,
1653        cx: &App,
1654    ) -> Task<WorktreeSnapshot> {
1655        cx.spawn(async move |cx| {
1656            // Get worktree path and snapshot
1657            let worktree_info = cx.update(|app_cx| {
1658                let worktree = worktree.read(app_cx);
1659                let path = worktree.abs_path().to_string_lossy().to_string();
1660                let snapshot = worktree.snapshot();
1661                (path, snapshot)
1662            });
1663
1664            let Ok((worktree_path, _snapshot)) = worktree_info else {
1665                return WorktreeSnapshot {
1666                    worktree_path: String::new(),
1667                    git_state: None,
1668                };
1669            };
1670
1671            let git_state = git_store
1672                .update(cx, |git_store, cx| {
1673                    git_store
1674                        .repositories()
1675                        .values()
1676                        .find(|repo| {
1677                            repo.read(cx)
1678                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1679                                .is_some()
1680                        })
1681                        .cloned()
1682                })
1683                .ok()
1684                .flatten()
1685                .map(|repo| {
1686                    repo.update(cx, |repo, _| {
1687                        let current_branch =
1688                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1689                        repo.send_job(None, |state, _| async move {
1690                            let RepositoryState::Local { backend, .. } = state else {
1691                                return GitState {
1692                                    remote_url: None,
1693                                    head_sha: None,
1694                                    current_branch,
1695                                    diff: None,
1696                                };
1697                            };
1698
1699                            let remote_url = backend.remote_url("origin");
1700                            let head_sha = backend.head_sha();
1701                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1702
1703                            GitState {
1704                                remote_url,
1705                                head_sha,
1706                                current_branch,
1707                                diff,
1708                            }
1709                        })
1710                    })
1711                });
1712
1713            let git_state = match git_state {
1714                Some(git_state) => match git_state.ok() {
1715                    Some(git_state) => git_state.await.ok(),
1716                    None => None,
1717                },
1718                None => None,
1719            };
1720
1721            WorktreeSnapshot {
1722                worktree_path,
1723                git_state,
1724            }
1725        })
1726    }
1727
1728    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1729        let mut markdown = Vec::new();
1730
1731        if let Some(summary) = self.summary() {
1732            writeln!(markdown, "# {summary}\n")?;
1733        };
1734
1735        for message in self.messages() {
1736            writeln!(
1737                markdown,
1738                "## {role}\n",
1739                role = match message.role {
1740                    Role::User => "User",
1741                    Role::Assistant => "Assistant",
1742                    Role::System => "System",
1743                }
1744            )?;
1745
1746            if !message.context.is_empty() {
1747                writeln!(markdown, "{}", message.context)?;
1748            }
1749
1750            for segment in &message.segments {
1751                match segment {
1752                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1753                    MessageSegment::Thinking(text) => {
1754                        writeln!(markdown, "<think>{}</think>\n", text)?
1755                    }
1756                }
1757            }
1758
1759            for tool_use in self.tool_uses_for_message(message.id, cx) {
1760                writeln!(
1761                    markdown,
1762                    "**Use Tool: {} ({})**",
1763                    tool_use.name, tool_use.id
1764                )?;
1765                writeln!(markdown, "```json")?;
1766                writeln!(
1767                    markdown,
1768                    "{}",
1769                    serde_json::to_string_pretty(&tool_use.input)?
1770                )?;
1771                writeln!(markdown, "```")?;
1772            }
1773
1774            for tool_result in self.tool_results_for_message(message.id) {
1775                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1776                if tool_result.is_error {
1777                    write!(markdown, " (Error)")?;
1778                }
1779
1780                writeln!(markdown, "**\n")?;
1781                writeln!(markdown, "{}", tool_result.content)?;
1782            }
1783        }
1784
1785        Ok(String::from_utf8_lossy(&markdown).to_string())
1786    }
1787
1788    pub fn keep_edits_in_range(
1789        &mut self,
1790        buffer: Entity<language::Buffer>,
1791        buffer_range: Range<language::Anchor>,
1792        cx: &mut Context<Self>,
1793    ) {
1794        self.action_log.update(cx, |action_log, cx| {
1795            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1796        });
1797    }
1798
1799    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1800        self.action_log
1801            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1802    }
1803
1804    pub fn reject_edits_in_ranges(
1805        &mut self,
1806        buffer: Entity<language::Buffer>,
1807        buffer_ranges: Vec<Range<language::Anchor>>,
1808        cx: &mut Context<Self>,
1809    ) -> Task<Result<()>> {
1810        self.action_log.update(cx, |action_log, cx| {
1811            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1812        })
1813    }
1814
1815    pub fn action_log(&self) -> &Entity<ActionLog> {
1816        &self.action_log
1817    }
1818
1819    pub fn project(&self) -> &Entity<Project> {
1820        &self.project
1821    }
1822
1823    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1824        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1825            return;
1826        }
1827
1828        let now = Instant::now();
1829        if let Some(last) = self.last_auto_capture_at {
1830            if now.duration_since(last).as_secs() < 10 {
1831                return;
1832            }
1833        }
1834
1835        self.last_auto_capture_at = Some(now);
1836
1837        let thread_id = self.id().clone();
1838        let github_login = self
1839            .project
1840            .read(cx)
1841            .user_store()
1842            .read(cx)
1843            .current_user()
1844            .map(|user| user.github_login.clone());
1845        let client = self.project.read(cx).client().clone();
1846        let serialize_task = self.serialize(cx);
1847
1848        cx.background_executor()
1849            .spawn(async move {
1850                if let Ok(serialized_thread) = serialize_task.await {
1851                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1852                        telemetry::event!(
1853                            "Agent Thread Auto-Captured",
1854                            thread_id = thread_id.to_string(),
1855                            thread_data = thread_data,
1856                            auto_capture_reason = "tracked_user",
1857                            github_login = github_login
1858                        );
1859
1860                        client.telemetry().flush_events();
1861                    }
1862                }
1863            })
1864            .detach();
1865    }
1866
1867    pub fn cumulative_token_usage(&self) -> TokenUsage {
1868        self.cumulative_token_usage
1869    }
1870
1871    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1872        let model_registry = LanguageModelRegistry::read_global(cx);
1873        let Some(model) = model_registry.default_model() else {
1874            return TotalTokenUsage::default();
1875        };
1876
1877        let max = model.model.max_token_count();
1878
1879        if let Some(exceeded_error) = &self.exceeded_window_error {
1880            if model.model.id() == exceeded_error.model_id {
1881                return TotalTokenUsage {
1882                    total: exceeded_error.token_count,
1883                    max,
1884                    ratio: TokenUsageRatio::Exceeded,
1885                };
1886            }
1887        }
1888
1889        #[cfg(debug_assertions)]
1890        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1891            .unwrap_or("0.8".to_string())
1892            .parse()
1893            .unwrap();
1894        #[cfg(not(debug_assertions))]
1895        let warning_threshold: f32 = 0.8;
1896
1897        let total = self.cumulative_token_usage.total_tokens() as usize;
1898
1899        let ratio = if total >= max {
1900            TokenUsageRatio::Exceeded
1901        } else if total as f32 / max as f32 >= warning_threshold {
1902            TokenUsageRatio::Warning
1903        } else {
1904            TokenUsageRatio::Normal
1905        };
1906
1907        TotalTokenUsage { total, max, ratio }
1908    }
1909
1910    pub fn deny_tool_use(
1911        &mut self,
1912        tool_use_id: LanguageModelToolUseId,
1913        tool_name: Arc<str>,
1914        cx: &mut Context<Self>,
1915    ) {
1916        let err = Err(anyhow::anyhow!(
1917            "Permission to run tool action denied by user"
1918        ));
1919
1920        self.tool_use
1921            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1922        self.tool_finished(tool_use_id.clone(), None, true, cx);
1923    }
1924}
1925
1926#[derive(Debug, Clone, Error)]
1927pub enum ThreadError {
1928    #[error("Payment required")]
1929    PaymentRequired,
1930    #[error("Max monthly spend reached")]
1931    MaxMonthlySpendReached,
1932    #[error("Message {header}: {message}")]
1933    Message {
1934        header: SharedString,
1935        message: SharedString,
1936    },
1937}
1938
1939#[derive(Debug, Clone)]
1940pub enum ThreadEvent {
1941    ShowError(ThreadError),
1942    StreamedCompletion,
1943    StreamedAssistantText(MessageId, String),
1944    StreamedAssistantThinking(MessageId, String),
1945    Stopped(Result<StopReason, Arc<anyhow::Error>>),
1946    MessageAdded(MessageId),
1947    MessageEdited(MessageId),
1948    MessageDeleted(MessageId),
1949    SummaryGenerated,
1950    SummaryChanged,
1951    UsePendingTools {
1952        tool_uses: Vec<PendingToolUse>,
1953    },
1954    ToolFinished {
1955        #[allow(unused)]
1956        tool_use_id: LanguageModelToolUseId,
1957        /// The pending tool use that corresponds to this tool.
1958        pending_tool_use: Option<PendingToolUse>,
1959    },
1960    CheckpointChanged,
1961    ToolConfirmationNeeded,
1962}
1963
1964impl EventEmitter<ThreadEvent> for Thread {}
1965
1966struct PendingCompletion {
1967    id: usize,
1968    _task: Task<()>,
1969}
1970
1971#[cfg(test)]
1972mod tests {
1973    use super::*;
1974    use crate::{ThreadStore, context_store::ContextStore, thread_store};
1975    use assistant_settings::AssistantSettings;
1976    use context_server::ContextServerSettings;
1977    use editor::EditorSettings;
1978    use gpui::TestAppContext;
1979    use project::{FakeFs, Project};
1980    use prompt_store::PromptBuilder;
1981    use serde_json::json;
1982    use settings::{Settings, SettingsStore};
1983    use std::sync::Arc;
1984    use theme::ThemeSettings;
1985    use util::path;
1986    use workspace::Workspace;
1987
1988    #[gpui::test]
1989    async fn test_message_with_context(cx: &mut TestAppContext) {
1990        init_test_settings(cx);
1991
1992        let project = create_test_project(
1993            cx,
1994            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
1995        )
1996        .await;
1997
1998        let (_workspace, _thread_store, thread, context_store) =
1999            setup_test_environment(cx, project.clone()).await;
2000
2001        add_file_to_context(&project, &context_store, "test/code.rs", cx)
2002            .await
2003            .unwrap();
2004
2005        let context =
2006            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2007
2008        // Insert user message with context
2009        let message_id = thread.update(cx, |thread, cx| {
2010            thread.insert_user_message("Please explain this code", vec![context], None, cx)
2011        });
2012
2013        // Check content and context in message object
2014        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2015
2016        // Use different path format strings based on platform for the test
2017        #[cfg(windows)]
2018        let path_part = r"test\code.rs";
2019        #[cfg(not(windows))]
2020        let path_part = "test/code.rs";
2021
2022        let expected_context = format!(
2023            r#"
2024<context>
2025The following items were attached by the user. You don't need to use other tools to read them.
2026
2027<files>
2028```rs {path_part}
2029fn main() {{
2030    println!("Hello, world!");
2031}}
2032```
2033</files>
2034</context>
2035"#
2036        );
2037
2038        assert_eq!(message.role, Role::User);
2039        assert_eq!(message.segments.len(), 1);
2040        assert_eq!(
2041            message.segments[0],
2042            MessageSegment::Text("Please explain this code".to_string())
2043        );
2044        assert_eq!(message.context, expected_context);
2045
2046        // Check message in request
2047        let request = thread.read_with(cx, |thread, cx| {
2048            thread.to_completion_request(RequestKind::Chat, cx)
2049        });
2050
2051        assert_eq!(request.messages.len(), 2);
2052        let expected_full_message = format!("{}Please explain this code", expected_context);
2053        assert_eq!(request.messages[1].string_contents(), expected_full_message);
2054    }
2055
2056    #[gpui::test]
2057    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
2058        init_test_settings(cx);
2059
2060        let project = create_test_project(
2061            cx,
2062            json!({
2063                "file1.rs": "fn function1() {}\n",
2064                "file2.rs": "fn function2() {}\n",
2065                "file3.rs": "fn function3() {}\n",
2066            }),
2067        )
2068        .await;
2069
2070        let (_, _thread_store, thread, context_store) =
2071            setup_test_environment(cx, project.clone()).await;
2072
2073        // Open files individually
2074        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
2075            .await
2076            .unwrap();
2077        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
2078            .await
2079            .unwrap();
2080        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
2081            .await
2082            .unwrap();
2083
2084        // Get the context objects
2085        let contexts = context_store.update(cx, |store, _| store.context().clone());
2086        assert_eq!(contexts.len(), 3);
2087
2088        // First message with context 1
2089        let message1_id = thread.update(cx, |thread, cx| {
2090            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
2091        });
2092
2093        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
2094        let message2_id = thread.update(cx, |thread, cx| {
2095            thread.insert_user_message(
2096                "Message 2",
2097                vec![contexts[0].clone(), contexts[1].clone()],
2098                None,
2099                cx,
2100            )
2101        });
2102
2103        // Third message with all three contexts (contexts 1 and 2 should be skipped)
2104        let message3_id = thread.update(cx, |thread, cx| {
2105            thread.insert_user_message(
2106                "Message 3",
2107                vec![
2108                    contexts[0].clone(),
2109                    contexts[1].clone(),
2110                    contexts[2].clone(),
2111                ],
2112                None,
2113                cx,
2114            )
2115        });
2116
2117        // Check what contexts are included in each message
2118        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
2119            (
2120                thread.message(message1_id).unwrap().clone(),
2121                thread.message(message2_id).unwrap().clone(),
2122                thread.message(message3_id).unwrap().clone(),
2123            )
2124        });
2125
2126        // First message should include context 1
2127        assert!(message1.context.contains("file1.rs"));
2128
2129        // Second message should include only context 2 (not 1)
2130        assert!(!message2.context.contains("file1.rs"));
2131        assert!(message2.context.contains("file2.rs"));
2132
2133        // Third message should include only context 3 (not 1 or 2)
2134        assert!(!message3.context.contains("file1.rs"));
2135        assert!(!message3.context.contains("file2.rs"));
2136        assert!(message3.context.contains("file3.rs"));
2137
2138        // Check entire request to make sure all contexts are properly included
2139        let request = thread.read_with(cx, |thread, cx| {
2140            thread.to_completion_request(RequestKind::Chat, cx)
2141        });
2142
2143        // The request should contain all 3 messages
2144        assert_eq!(request.messages.len(), 4);
2145
2146        // Check that the contexts are properly formatted in each message
2147        assert!(request.messages[1].string_contents().contains("file1.rs"));
2148        assert!(!request.messages[1].string_contents().contains("file2.rs"));
2149        assert!(!request.messages[1].string_contents().contains("file3.rs"));
2150
2151        assert!(!request.messages[2].string_contents().contains("file1.rs"));
2152        assert!(request.messages[2].string_contents().contains("file2.rs"));
2153        assert!(!request.messages[2].string_contents().contains("file3.rs"));
2154
2155        assert!(!request.messages[3].string_contents().contains("file1.rs"));
2156        assert!(!request.messages[3].string_contents().contains("file2.rs"));
2157        assert!(request.messages[3].string_contents().contains("file3.rs"));
2158    }
2159
2160    #[gpui::test]
2161    async fn test_message_without_files(cx: &mut TestAppContext) {
2162        init_test_settings(cx);
2163
2164        let project = create_test_project(
2165            cx,
2166            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2167        )
2168        .await;
2169
2170        let (_, _thread_store, thread, _context_store) =
2171            setup_test_environment(cx, project.clone()).await;
2172
2173        // Insert user message without any context (empty context vector)
2174        let message_id = thread.update(cx, |thread, cx| {
2175            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
2176        });
2177
2178        // Check content and context in message object
2179        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());
2180
2181        // Context should be empty when no files are included
2182        assert_eq!(message.role, Role::User);
2183        assert_eq!(message.segments.len(), 1);
2184        assert_eq!(
2185            message.segments[0],
2186            MessageSegment::Text("What is the best way to learn Rust?".to_string())
2187        );
2188        assert_eq!(message.context, "");
2189
2190        // Check message in request
2191        let request = thread.read_with(cx, |thread, cx| {
2192            thread.to_completion_request(RequestKind::Chat, cx)
2193        });
2194
2195        assert_eq!(request.messages.len(), 2);
2196        assert_eq!(
2197            request.messages[1].string_contents(),
2198            "What is the best way to learn Rust?"
2199        );
2200
2201        // Add second message, also without context
2202        let message2_id = thread.update(cx, |thread, cx| {
2203            thread.insert_user_message("Are there any good books?", vec![], None, cx)
2204        });
2205
2206        let message2 =
2207            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
2208        assert_eq!(message2.context, "");
2209
2210        // Check that both messages appear in the request
2211        let request = thread.read_with(cx, |thread, cx| {
2212            thread.to_completion_request(RequestKind::Chat, cx)
2213        });
2214
2215        assert_eq!(request.messages.len(), 3);
2216        assert_eq!(
2217            request.messages[1].string_contents(),
2218            "What is the best way to learn Rust?"
2219        );
2220        assert_eq!(
2221            request.messages[2].string_contents(),
2222            "Are there any good books?"
2223        );
2224    }
2225
2226    #[gpui::test]
2227    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
2228        init_test_settings(cx);
2229
2230        let project = create_test_project(
2231            cx,
2232            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
2233        )
2234        .await;
2235
2236        let (_workspace, _thread_store, thread, context_store) =
2237            setup_test_environment(cx, project.clone()).await;
2238
2239        // Open buffer and add it to context
2240        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
2241            .await
2242            .unwrap();
2243
2244        let context =
2245            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());
2246
2247        // Insert user message with the buffer as context
2248        thread.update(cx, |thread, cx| {
2249            thread.insert_user_message("Explain this code", vec![context], None, cx)
2250        });
2251
2252        // Create a request and check that it doesn't have a stale buffer warning yet
2253        let initial_request = thread.read_with(cx, |thread, cx| {
2254            thread.to_completion_request(RequestKind::Chat, cx)
2255        });
2256
2257        // Make sure we don't have a stale file warning yet
2258        let has_stale_warning = initial_request.messages.iter().any(|msg| {
2259            msg.string_contents()
2260                .contains("These files changed since last read:")
2261        });
2262        assert!(
2263            !has_stale_warning,
2264            "Should not have stale buffer warning before buffer is modified"
2265        );
2266
2267        // Modify the buffer
2268        buffer.update(cx, |buffer, cx| {
2269            // Find a position at the end of line 1
2270            buffer.edit(
2271                [(1..1, "\n    println!(\"Added a new line\");\n")],
2272                None,
2273                cx,
2274            );
2275        });
2276
2277        // Insert another user message without context
2278        thread.update(cx, |thread, cx| {
2279            thread.insert_user_message("What does the code do now?", vec![], None, cx)
2280        });
2281
2282        // Create a new request and check for the stale buffer warning
2283        let new_request = thread.read_with(cx, |thread, cx| {
2284            thread.to_completion_request(RequestKind::Chat, cx)
2285        });
2286
2287        // We should have a stale file warning as the last message
2288        let last_message = new_request
2289            .messages
2290            .last()
2291            .expect("Request should have messages");
2292
2293        // The last message should be the stale buffer notification
2294        assert_eq!(last_message.role, Role::User);
2295
2296        // Check the exact content of the message
2297        let expected_content = "These files changed since last read:\n- code.rs\n";
2298        assert_eq!(
2299            last_message.string_contents(),
2300            expected_content,
2301            "Last message should be exactly the stale buffer notification"
2302        );
2303    }
2304
2305    fn init_test_settings(cx: &mut TestAppContext) {
2306        cx.update(|cx| {
2307            let settings_store = SettingsStore::test(cx);
2308            cx.set_global(settings_store);
2309            language::init(cx);
2310            Project::init_settings(cx);
2311            AssistantSettings::register(cx);
2312            thread_store::init(cx);
2313            workspace::init_settings(cx);
2314            ThemeSettings::register(cx);
2315            ContextServerSettings::register(cx);
2316            EditorSettings::register(cx);
2317        });
2318    }
2319
2320    // Helper to create a test project with test files
2321    async fn create_test_project(
2322        cx: &mut TestAppContext,
2323        files: serde_json::Value,
2324    ) -> Entity<Project> {
2325        let fs = FakeFs::new(cx.executor());
2326        fs.insert_tree(path!("/test"), files).await;
2327        Project::test(fs, [path!("/test").as_ref()], cx).await
2328    }
2329
2330    async fn setup_test_environment(
2331        cx: &mut TestAppContext,
2332        project: Entity<Project>,
2333    ) -> (
2334        Entity<Workspace>,
2335        Entity<ThreadStore>,
2336        Entity<Thread>,
2337        Entity<ContextStore>,
2338    ) {
2339        let (workspace, cx) =
2340            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2341
2342        let thread_store = cx
2343            .update(|_, cx| {
2344                ThreadStore::load(
2345                    project.clone(),
2346                    cx.new(|_| ToolWorkingSet::default()),
2347                    Arc::new(PromptBuilder::new(None).unwrap()),
2348                    cx,
2349                )
2350            })
2351            .await;
2352
2353        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2354        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2355
2356        (workspace, thread_store, thread, context_store)
2357    }
2358
2359    async fn add_file_to_context(
2360        project: &Entity<Project>,
2361        context_store: &Entity<ContextStore>,
2362        path: &str,
2363        cx: &mut TestAppContext,
2364    ) -> Result<Entity<language::Buffer>> {
2365        let buffer_path = project
2366            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2367            .unwrap();
2368
2369        let buffer = project
2370            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2371            .await
2372            .unwrap();
2373
2374        context_store
2375            .update(cx, |store, cx| {
2376                store.add_file_from_buffer(buffer.clone(), cx)
2377            })
2378            .await?;
2379
2380        Ok(buffer)
2381    }
2382}