thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, AnyToolCard, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  19    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  20    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  21    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  22    ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
  23};
  24use project::Project;
  25use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  26use prompt_store::PromptBuilder;
  27use proto::Plan;
  28use schemars::JsonSchema;
  29use serde::{Deserialize, Serialize};
  30use settings::Settings;
  31use thiserror::Error;
  32use util::{ResultExt as _, TryFutureExt as _, post_inc};
  33use uuid::Uuid;
  34
  35use crate::context::{AssistantContext, ContextId, format_context_as_string};
  36use crate::thread_store::{
  37    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  38    SerializedToolUse, SharedProjectContext,
  39};
  40use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  41
/// The kind of completion request being built for a thread.
///
/// A value of this type is passed to `Thread::send_to_model`, which forwards it when
/// building the completion request.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// Presumably an ordinary conversation turn; the consumer is not visible here.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
  48
  49#[derive(
  50    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  51)]
  52pub struct ThreadId(Arc<str>);
  53
  54impl ThreadId {
  55    pub fn new() -> Self {
  56        Self(Uuid::new_v4().to_string().into())
  57    }
  58}
  59
  60impl std::fmt::Display for ThreadId {
  61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  62        write!(f, "{}", self.0)
  63    }
  64}
  65
  66impl From<&str> for ThreadId {
  67    fn from(value: &str) -> Self {
  68        Self(value.into())
  69    }
  70}
  71
  72#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  73pub struct MessageId(pub(crate) usize);
  74
  75impl MessageId {
  76    fn post_inc(&mut self) -> Self {
  77        Self(post_inc(&mut self.0))
  78    }
  79}
  80
  81/// A message in a [`Thread`].
  82#[derive(Debug, Clone)]
  83pub struct Message {
  84    pub id: MessageId,
  85    pub role: Role,
  86    pub segments: Vec<MessageSegment>,
  87    pub context: String,
  88}
  89
  90impl Message {
  91    /// Returns whether the message contains any meaningful text that should be displayed
  92    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  93    pub fn should_display_content(&self) -> bool {
  94        self.segments.iter().all(|segment| segment.should_display())
  95    }
  96
  97    pub fn push_thinking(&mut self, text: &str) {
  98        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  99            segment.push_str(text);
 100        } else {
 101            self.segments
 102                .push(MessageSegment::Thinking(text.to_string()));
 103        }
 104    }
 105
 106    pub fn push_text(&mut self, text: &str) {
 107        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 108            segment.push_str(text);
 109        } else {
 110            self.segments.push(MessageSegment::Text(text.to_string()));
 111        }
 112    }
 113
 114    pub fn to_string(&self) -> String {
 115        let mut result = String::new();
 116
 117        if !self.context.is_empty() {
 118            result.push_str(&self.context);
 119        }
 120
 121        for segment in &self.segments {
 122            match segment {
 123                MessageSegment::Text(text) => result.push_str(text),
 124                MessageSegment::Thinking(text) => {
 125                    result.push_str("<think>");
 126                    result.push_str(text);
 127                    result.push_str("</think>");
 128                }
 129            }
 130        }
 131
 132        result
 133    }
 134}
 135
 136#[derive(Debug, Clone, PartialEq, Eq)]
 137pub enum MessageSegment {
 138    Text(String),
 139    Thinking(String),
 140}
 141
 142impl MessageSegment {
 143    pub fn text_mut(&mut self) -> &mut String {
 144        match self {
 145            Self::Text(text) => text,
 146            Self::Thinking(text) => text,
 147        }
 148    }
 149
 150    pub fn should_display(&self) -> bool {
 151        // We add USING_TOOL_MARKER when making a request that includes tool uses
 152        // without non-whitespace text around them, and this can cause the model
 153        // to mimic the pattern, so we consider those segments not displayable.
 154        match self {
 155            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 156            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 157        }
 158    }
 159}
 160
/// Serializable snapshot of a project, stored with a thread as its
/// `initial_project_snapshot`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// One entry per worktree.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Presumably the paths of buffers with unsaved changes; the producer is not visible here.
    pub unsaved_buffer_paths: Vec<String>,
    /// Presumably the time the snapshot was taken; the producer is not visible here.
    pub timestamp: DateTime<Utc>,
}
 167
/// Serializable snapshot of a single worktree, part of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree, as a string.
    pub worktree_path: String,
    /// Git information for the worktree, when present.
    pub git_state: Option<GitState>,
}
 173
/// Serializable git information attached to a [`WorktreeSnapshot`].
///
/// Every field is optional; how each one is populated is not visible in this part of
/// the file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    pub diff: Option<String>,
}
 181
/// A git checkpoint associated with a message of a thread.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// The message the checkpoint belongs to; restoring truncates the thread at this message.
    message_id: MessageId,
    /// The checkpoint handed to the git store when restoring, comparing or deleting.
    git_checkpoint: GitStoreCheckpoint,
}
 187
/// Feedback rating, stored on a thread both as a whole (`feedback`) and per message
/// (`message_feedback`).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 193
 194pub enum LastRestoreCheckpoint {
 195    Pending {
 196        message_id: MessageId,
 197    },
 198    Error {
 199        message_id: MessageId,
 200        error: String,
 201    },
 202}
 203
 204impl LastRestoreCheckpoint {
 205    pub fn message_id(&self) -> MessageId {
 206        match self {
 207            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 208            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 209        }
 210    }
 211}
 212
/// Progress of a thread's detailed summary. Defaults to [`Self::NotGenerated`].
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary exists.
    #[default]
    NotGenerated,
    /// A detailed summary is being generated.
    Generating {
        // Presumably the last message covered by the summary; confirm against the generator.
        message_id: MessageId,
    },
    /// A detailed summary is available.
    Generated {
        /// The summary text.
        text: SharedString,
        // Presumably the last message covered by the summary; confirm against the generator.
        message_id: MessageId,
    },
}
 225
 226#[derive(Default)]
 227pub struct TotalTokenUsage {
 228    pub total: usize,
 229    pub max: usize,
 230}
 231
 232impl TotalTokenUsage {
 233    pub fn ratio(&self) -> TokenUsageRatio {
 234        #[cfg(debug_assertions)]
 235        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
 236            .unwrap_or("0.8".to_string())
 237            .parse()
 238            .unwrap();
 239        #[cfg(not(debug_assertions))]
 240        let warning_threshold: f32 = 0.8;
 241
 242        if self.total >= self.max {
 243            TokenUsageRatio::Exceeded
 244        } else if self.total as f32 / self.max as f32 >= warning_threshold {
 245            TokenUsageRatio::Warning
 246        } else {
 247            TokenUsageRatio::Normal
 248        }
 249    }
 250
 251    pub fn add(&self, tokens: usize) -> TotalTokenUsage {
 252        TotalTokenUsage {
 253            total: self.total + tokens,
 254            max: self.max,
 255        }
 256    }
 257}
 258
 259#[derive(Debug, Default, PartialEq, Eq)]
 260pub enum TokenUsageRatio {
 261    #[default]
 262    Normal,
 263    Warning,
 264    Exceeded,
 265}
 266
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Unique identifier of this thread.
    id: ThreadId,
    /// Time of the last change; refreshed by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// Short summary; `set_summary` refuses to change it while this is `None`.
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    /// Messages in insertion order.
    messages: Vec<Message>,
    /// Id that `insert_message` hands out next.
    next_message_id: MessageId,
    /// Every context attached so far, keyed by id; consulted to skip context that was
    /// already included in an earlier message.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Ids of the contexts that were newly attached with each user message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    project_context: SharedProjectContext,
    /// Checkpoints kept for messages, keyed by message id.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// Completions in flight; the thread counts as generating while this is non-empty.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// State of the latest checkpoint restore while it is pending or after it failed.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint recorded with the latest user message, awaiting
    /// `finalize_pending_checkpoint`.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Snapshot of the project taken when the thread was created (or loaded from the
    /// serialized thread).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    request_token_usage: Vec<TokenUsage>,
    cumulative_token_usage: TokenUsage,
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
}
 297
/// Details recorded when a message exceeded the model's context window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 305
 306impl Thread {
 307    pub fn new(
 308        project: Entity<Project>,
 309        tools: Entity<ToolWorkingSet>,
 310        prompt_builder: Arc<PromptBuilder>,
 311        system_prompt: SharedProjectContext,
 312        cx: &mut Context<Self>,
 313    ) -> Self {
 314        Self {
 315            id: ThreadId::new(),
 316            updated_at: Utc::now(),
 317            summary: None,
 318            pending_summary: Task::ready(None),
 319            detailed_summary_state: DetailedSummaryState::NotGenerated,
 320            messages: Vec::new(),
 321            next_message_id: MessageId(0),
 322            context: BTreeMap::default(),
 323            context_by_message: HashMap::default(),
 324            project_context: system_prompt,
 325            checkpoints_by_message: HashMap::default(),
 326            completion_count: 0,
 327            pending_completions: Vec::new(),
 328            project: project.clone(),
 329            prompt_builder,
 330            tools: tools.clone(),
 331            last_restore_checkpoint: None,
 332            pending_checkpoint: None,
 333            tool_use: ToolUseState::new(tools.clone()),
 334            action_log: cx.new(|_| ActionLog::new(project.clone())),
 335            initial_project_snapshot: {
 336                let project_snapshot = Self::project_snapshot(project, cx);
 337                cx.foreground_executor()
 338                    .spawn(async move { Some(project_snapshot.await) })
 339                    .shared()
 340            },
 341            request_token_usage: Vec::new(),
 342            cumulative_token_usage: TokenUsage::default(),
 343            exceeded_window_error: None,
 344            feedback: None,
 345            message_feedback: HashMap::default(),
 346            last_auto_capture_at: None,
 347        }
 348    }
 349
 350    pub fn deserialize(
 351        id: ThreadId,
 352        serialized: SerializedThread,
 353        project: Entity<Project>,
 354        tools: Entity<ToolWorkingSet>,
 355        prompt_builder: Arc<PromptBuilder>,
 356        project_context: SharedProjectContext,
 357        cx: &mut Context<Self>,
 358    ) -> Self {
 359        let next_message_id = MessageId(
 360            serialized
 361                .messages
 362                .last()
 363                .map(|message| message.id.0 + 1)
 364                .unwrap_or(0),
 365        );
 366        let tool_use =
 367            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 368
 369        Self {
 370            id,
 371            updated_at: serialized.updated_at,
 372            summary: Some(serialized.summary),
 373            pending_summary: Task::ready(None),
 374            detailed_summary_state: serialized.detailed_summary_state,
 375            messages: serialized
 376                .messages
 377                .into_iter()
 378                .map(|message| Message {
 379                    id: message.id,
 380                    role: message.role,
 381                    segments: message
 382                        .segments
 383                        .into_iter()
 384                        .map(|segment| match segment {
 385                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 386                            SerializedMessageSegment::Thinking { text } => {
 387                                MessageSegment::Thinking(text)
 388                            }
 389                        })
 390                        .collect(),
 391                    context: message.context,
 392                })
 393                .collect(),
 394            next_message_id,
 395            context: BTreeMap::default(),
 396            context_by_message: HashMap::default(),
 397            project_context,
 398            checkpoints_by_message: HashMap::default(),
 399            completion_count: 0,
 400            pending_completions: Vec::new(),
 401            last_restore_checkpoint: None,
 402            pending_checkpoint: None,
 403            project: project.clone(),
 404            prompt_builder,
 405            tools,
 406            tool_use,
 407            action_log: cx.new(|_| ActionLog::new(project)),
 408            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 409            request_token_usage: serialized.request_token_usage,
 410            cumulative_token_usage: serialized.cumulative_token_usage,
 411            exceeded_window_error: None,
 412            feedback: None,
 413            message_feedback: HashMap::default(),
 414            last_auto_capture_at: None,
 415        }
 416    }
 417
    /// Returns the unique identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 421
    /// Returns `true` when the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 425
    /// Returns the time the thread was last marked as updated.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 429
    /// Sets the thread's update time to the current UTC time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 433
    /// Returns a clone of the thread's summary, or `None` when it has none.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
 437
    /// Returns a clone of the project context handle this thread was created with.
    pub fn project_context(&self) -> SharedProjectContext {
        self.project_context.clone()
    }
 441
 442    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 443
 444    pub fn summary_or_default(&self) -> SharedString {
 445        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 446    }
 447
 448    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 449        let Some(current_summary) = &self.summary else {
 450            // Don't allow setting summary until generated
 451            return;
 452        };
 453
 454        let mut new_summary = new_summary.into();
 455
 456        if new_summary.is_empty() {
 457            new_summary = Self::DEFAULT_SUMMARY;
 458        }
 459
 460        if current_summary != &new_summary {
 461            self.summary = Some(new_summary);
 462            cx.emit(ThreadEvent::SummaryChanged);
 463        }
 464    }
 465
 466    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 467        self.latest_detailed_summary()
 468            .unwrap_or_else(|| self.text().into())
 469    }
 470
 471    fn latest_detailed_summary(&self) -> Option<SharedString> {
 472        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 473            Some(text.clone())
 474        } else {
 475            None
 476        }
 477    }
 478
 479    pub fn message(&self, id: MessageId) -> Option<&Message> {
 480        self.messages.iter().find(|message| message.id == id)
 481    }
 482
    /// Iterates over the thread's messages in the order they are stored.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }
 486
 487    pub fn is_generating(&self) -> bool {
 488        !self.pending_completions.is_empty() || !self.all_tools_finished()
 489    }
 490
    /// Returns the tool working set entity used by this thread.
    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
        &self.tools
    }
 494
 495    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 496        self.tool_use
 497            .pending_tool_uses()
 498            .into_iter()
 499            .find(|tool_use| &tool_use.id == id)
 500    }
 501
 502    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 503        self.tool_use
 504            .pending_tool_uses()
 505            .into_iter()
 506            .filter(|tool_use| tool_use.status.needs_confirmation())
 507    }
 508
    /// Returns `true` when at least one tool use is pending, whatever its status.
    pub fn has_pending_tool_uses(&self) -> bool {
        !self.tool_use.pending_tool_uses().is_empty()
    }
 512
    /// Returns a clone of the checkpoint stored for the given message, if any.
    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
        self.checkpoints_by_message.get(&id).cloned()
    }
 516
    /// Restores the project to `checkpoint` through the git store.
    ///
    /// The restore is immediately recorded as pending. When it succeeds, the thread is
    /// truncated at the checkpoint's message and the restore state is cleared; when it
    /// fails, the error is recorded in the restore state. In both cases the pending
    /// checkpoint is dropped and [`ThreadEvent::CheckpointChanged`] is emitted again.
    ///
    /// The returned task resolves to the result of the git restore; it also fails if the
    /// thread can no longer be updated once the restore has finished.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Publish the pending state before the asynchronous restore starts.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let git_store = self.project().read(cx).git_store().clone();
        let restore = git_store.update(cx, |git_store, cx| {
            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
        });

        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    // Keep the failure around so it can be read via `last_restore_checkpoint`.
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Drop the checkpoint's message and everything after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 552
    /// Decides whether the pending checkpoint is worth keeping once generation is over.
    ///
    /// Does nothing while the thread is still generating or when no checkpoint is
    /// pending. Otherwise a new checkpoint of the current state is taken and compared
    /// with the pending one: if they are equal the pending checkpoint is deleted,
    /// otherwise it is stored for its message. The freshly taken checkpoint is always
    /// deleted afterwards. If taking it fails, the pending checkpoint is stored.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison is treated like "not equal", so the checkpoint is kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    // The two checkpoints compare equal, so the pending one is not kept.
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The comparison checkpoint is only needed for the comparison above.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // Without a final checkpoint there is nothing to compare against, so the
            // pending checkpoint is stored.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 603
 604    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 605        self.checkpoints_by_message
 606            .insert(checkpoint.message_id, checkpoint);
 607        cx.emit(ThreadEvent::CheckpointChanged);
 608        cx.notify();
 609    }
 610
    /// Returns the state of the latest checkpoint restore while it is pending or after it
    /// failed; `None` otherwise.
    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
        self.last_restore_checkpoint.as_ref()
    }
 614
 615    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 616        let Some(message_ix) = self
 617            .messages
 618            .iter()
 619            .rposition(|message| message.id == message_id)
 620        else {
 621            return;
 622        };
 623        for deleted_message in self.messages.drain(message_ix..) {
 624            self.context_by_message.remove(&deleted_message.id);
 625            self.checkpoints_by_message.remove(&deleted_message.id);
 626        }
 627        cx.notify();
 628    }
 629
 630    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 631        self.context_by_message
 632            .get(&id)
 633            .into_iter()
 634            .flat_map(|context| {
 635                context
 636                    .iter()
 637                    .filter_map(|context_id| self.context.get(&context_id))
 638            })
 639    }
 640
 641    /// Returns whether all of the tool uses have finished running.
 642    pub fn all_tools_finished(&self) -> bool {
 643        // If the only pending tool uses left are the ones with errors, then
 644        // that means that we've finished running all of the pending tools.
 645        self.tool_use
 646            .pending_tool_uses()
 647            .iter()
 648            .all(|tool_use| tool_use.status.is_error())
 649    }
 650
    /// Returns the tool uses recorded for the given message, as reported by the
    /// tool-use state.
    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id, cx)
    }
 654
    /// Returns the tool results recorded for the given message, as reported by the
    /// tool-use state.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }
 658
    /// Returns the result of the tool use with the given id, if the tool-use state has one.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 662
 663    pub fn output_for_tool(&self, id: &LanguageModelToolUseId) -> Option<&Arc<str>> {
 664        Some(&self.tool_use.tool_result(id)?.content)
 665    }
 666
    /// Returns a clone of the result card of the tool use with the given id, if the
    /// tool-use state has one.
    pub fn card_for_tool(&self, id: &LanguageModelToolUseId) -> Option<AnyToolCard> {
        self.tool_use.tool_result_card(id).cloned()
    }
 670
    /// Returns whether the tool-use state reports tool results for the given message.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
 674
 675    /// Filter out contexts that have already been included in previous messages
 676    pub fn filter_new_context<'a>(
 677        &self,
 678        context: impl Iterator<Item = &'a AssistantContext>,
 679    ) -> impl Iterator<Item = &'a AssistantContext> {
 680        context.filter(|ctx| self.is_context_new(ctx))
 681    }
 682
    /// Returns `true` when no context with the same id has been attached to this thread.
    fn is_context_new(&self, context: &AssistantContext) -> bool {
        !self.context.contains_key(&context.id())
    }
 686
    /// Appends a user message with the given text and attaches the new context to it.
    ///
    /// Context whose id is already known to the thread is dropped. The remaining context
    /// is formatted into the message's `context` string, its buffers are reported to the
    /// action log, and its ids are recorded for the message. When `git_checkpoint` is
    /// provided it becomes the thread's pending checkpoint for this message.
    ///
    /// Returns the id of the new message.
    pub fn insert_user_message(
        &mut self,
        text: impl Into<String>,
        context: Vec<AssistantContext>,
        git_checkpoint: Option<GitStoreCheckpoint>,
        cx: &mut Context<Self>,
    ) -> MessageId {
        let text = text.into();

        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);

        // Only context that has not been attached to an earlier message is kept.
        let new_context: Vec<_> = context
            .into_iter()
            .filter(|ctx| self.is_context_new(ctx))
            .collect();

        if !new_context.is_empty() {
            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
                    message.context = context_string;
                }
            }

            self.action_log.update(cx, |log, cx| {
                // Track all buffers added as context
                for ctx in &new_context {
                    match ctx {
                        AssistantContext::File(file_ctx) => {
                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
                        }
                        AssistantContext::Directory(dir_ctx) => {
                            for context_buffer in &dir_ctx.context_buffers {
                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
                            }
                        }
                        AssistantContext::Symbol(symbol_ctx) => {
                            log.buffer_added_as_context(
                                symbol_ctx.context_symbol.buffer.clone(),
                                cx,
                            );
                        }
                        // No buffer is tracked for fetched URLs or attached threads.
                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
                    }
                }
            });
        }

        // Remember which contexts belong to this message, then make them known to the
        // thread so they are not attached again later.
        let context_ids = new_context
            .iter()
            .map(|context| context.id())
            .collect::<Vec<_>>();
        self.context.extend(
            new_context
                .into_iter()
                .map(|context| (context.id(), context)),
        );
        self.context_by_message.insert(message_id, context_ids);

        if let Some(git_checkpoint) = git_checkpoint {
            self.pending_checkpoint = Some(ThreadCheckpoint {
                message_id,
                git_checkpoint,
            });
        }

        self.auto_capture_telemetry(cx);

        message_id
    }
 756
 757    pub fn insert_message(
 758        &mut self,
 759        role: Role,
 760        segments: Vec<MessageSegment>,
 761        cx: &mut Context<Self>,
 762    ) -> MessageId {
 763        let id = self.next_message_id.post_inc();
 764        self.messages.push(Message {
 765            id,
 766            role,
 767            segments,
 768            context: String::new(),
 769        });
 770        self.touch_updated_at();
 771        cx.emit(ThreadEvent::MessageAdded(id));
 772        id
 773    }
 774
 775    pub fn edit_message(
 776        &mut self,
 777        id: MessageId,
 778        new_role: Role,
 779        new_segments: Vec<MessageSegment>,
 780        cx: &mut Context<Self>,
 781    ) -> bool {
 782        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 783            return false;
 784        };
 785        message.role = new_role;
 786        message.segments = new_segments;
 787        self.touch_updated_at();
 788        cx.emit(ThreadEvent::MessageEdited(id));
 789        true
 790    }
 791
 792    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 793        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 794            return false;
 795        };
 796        self.messages.remove(index);
 797        self.context_by_message.remove(&id);
 798        self.touch_updated_at();
 799        cx.emit(ThreadEvent::MessageDeleted(id));
 800        true
 801    }
 802
 803    /// Returns the representation of this [`Thread`] in a textual form.
 804    ///
 805    /// This is the representation we use when attaching a thread as context to another thread.
 806    pub fn text(&self) -> String {
 807        let mut text = String::new();
 808
 809        for message in &self.messages {
 810            text.push_str(match message.role {
 811                language_model::Role::User => "User:",
 812                language_model::Role::Assistant => "Assistant:",
 813                language_model::Role::System => "System:",
 814            });
 815            text.push('\n');
 816
 817            for segment in &message.segments {
 818                match segment {
 819                    MessageSegment::Text(content) => text.push_str(content),
 820                    MessageSegment::Thinking(content) => {
 821                        text.push_str(&format!("<think>{}</think>", content))
 822                    }
 823                }
 824            }
 825            text.push('\n');
 826        }
 827
 828        text
 829    }
 830
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// The returned task first waits for the initial project snapshot and then reads the
    /// thread's state in one go. It fails if the thread cannot be read at that point.
    /// A thread without a summary is serialized with the default summary.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            // Wait for the snapshot before touching the thread, so everything else is
            // read at a single point in time.
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking(text) => {
                                    SerializedMessageSegment::Thinking { text: text.clone() }
                                }
                            })
                            .collect(),
                        // Tool uses and results are stored alongside the message they belong to.
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                request_token_usage: this.request_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
 886
 887    pub fn send_to_model(
 888        &mut self,
 889        model: Arc<dyn LanguageModel>,
 890        request_kind: RequestKind,
 891        cx: &mut Context<Self>,
 892    ) {
 893        let mut request = self.to_completion_request(request_kind, cx);
 894        if model.supports_tools() {
 895            request.tools = {
 896                let mut tools = Vec::new();
 897                tools.extend(
 898                    self.tools()
 899                        .read(cx)
 900                        .enabled_tools(cx)
 901                        .into_iter()
 902                        .filter_map(|tool| {
 903                            // Skip tools that cannot be supported
 904                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 905                            Some(LanguageModelRequestTool {
 906                                name: tool.name(),
 907                                description: tool.description(),
 908                                input_schema,
 909                            })
 910                        }),
 911                );
 912
 913                tools
 914            };
 915        }
 916
 917        self.stream_completion(request, model, cx);
 918    }
 919
 920    pub fn used_tools_since_last_user_message(&self) -> bool {
 921        for message in self.messages.iter().rev() {
 922            if self.tool_use.message_has_tool_results(message.id) {
 923                return true;
 924            } else if message.role == Role::User {
 925                return false;
 926            }
 927        }
 928
 929        false
 930    }
 931
 932    pub fn to_completion_request(
 933        &self,
 934        request_kind: RequestKind,
 935        cx: &App,
 936    ) -> LanguageModelRequest {
 937        let mut request = LanguageModelRequest {
 938            messages: vec![],
 939            tools: Vec::new(),
 940            stop: Vec::new(),
 941            temperature: None,
 942        };
 943
 944        if let Some(project_context) = self.project_context.borrow().as_ref() {
 945            if let Some(system_prompt) = self
 946                .prompt_builder
 947                .generate_assistant_system_prompt(project_context)
 948                .context("failed to generate assistant system prompt")
 949                .log_err()
 950            {
 951                request.messages.push(LanguageModelRequestMessage {
 952                    role: Role::System,
 953                    content: vec![MessageContent::Text(system_prompt)],
 954                    cache: true,
 955                });
 956            }
 957        } else {
 958            log::error!("project_context not set.")
 959        }
 960
 961        for message in &self.messages {
 962            let mut request_message = LanguageModelRequestMessage {
 963                role: message.role,
 964                content: Vec::new(),
 965                cache: false,
 966            };
 967
 968            match request_kind {
 969                RequestKind::Chat => {
 970                    self.tool_use
 971                        .attach_tool_results(message.id, &mut request_message);
 972                }
 973                RequestKind::Summarize => {
 974                    // We don't care about tool use during summarization.
 975                    if self.tool_use.message_has_tool_results(message.id) {
 976                        continue;
 977                    }
 978                }
 979            }
 980
 981            if !message.segments.is_empty() {
 982                request_message
 983                    .content
 984                    .push(MessageContent::Text(message.to_string()));
 985            }
 986
 987            match request_kind {
 988                RequestKind::Chat => {
 989                    self.tool_use
 990                        .attach_tool_uses(message.id, &mut request_message);
 991                }
 992                RequestKind::Summarize => {
 993                    // We don't care about tool use during summarization.
 994                }
 995            };
 996
 997            request.messages.push(request_message);
 998        }
 999
1000        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1001        if let Some(last) = request.messages.last_mut() {
1002            last.cache = true;
1003        }
1004
1005        self.attached_tracked_files_state(&mut request.messages, cx);
1006
1007        request
1008    }
1009
1010    fn attached_tracked_files_state(
1011        &self,
1012        messages: &mut Vec<LanguageModelRequestMessage>,
1013        cx: &App,
1014    ) {
1015        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1016
1017        let mut stale_message = String::new();
1018
1019        let action_log = self.action_log.read(cx);
1020
1021        for stale_file in action_log.stale_buffers(cx) {
1022            let Some(file) = stale_file.read(cx).file() else {
1023                continue;
1024            };
1025
1026            if stale_message.is_empty() {
1027                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1028            }
1029
1030            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1031        }
1032
1033        let mut content = Vec::with_capacity(2);
1034
1035        if !stale_message.is_empty() {
1036            content.push(stale_message.into());
1037        }
1038
1039        if action_log.has_edited_files_since_project_diagnostics_check() {
1040            content.push(
1041                "\n\nWhen you're done making changes, make sure to check project diagnostics \
1042                and fix all errors AND warnings you introduced! \
1043                DO NOT mention you're going to do this until you're done."
1044                    .into(),
1045            );
1046        }
1047
1048        if !content.is_empty() {
1049            let context_message = LanguageModelRequestMessage {
1050                role: Role::User,
1051                content,
1052                cache: false,
1053            };
1054
1055            messages.push(context_message);
1056        }
1057    }
1058
1059    pub fn stream_completion(
1060        &mut self,
1061        request: LanguageModelRequest,
1062        model: Arc<dyn LanguageModel>,
1063        cx: &mut Context<Self>,
1064    ) {
1065        let pending_completion_id = post_inc(&mut self.completion_count);
1066        let task = cx.spawn(async move |thread, cx| {
1067            let stream = model.stream_completion(request, &cx);
1068            let initial_token_usage =
1069                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1070            let stream_completion = async {
1071                let mut events = stream.await?;
1072                let mut stop_reason = StopReason::EndTurn;
1073                let mut current_token_usage = TokenUsage::default();
1074
1075                while let Some(event) = events.next().await {
1076                    let event = event?;
1077
1078                    thread.update(cx, |thread, cx| {
1079                        match event {
1080                            LanguageModelCompletionEvent::StartMessage { .. } => {
1081                                thread.insert_message(
1082                                    Role::Assistant,
1083                                    vec![MessageSegment::Text(String::new())],
1084                                    cx,
1085                                );
1086                            }
1087                            LanguageModelCompletionEvent::Stop(reason) => {
1088                                stop_reason = reason;
1089                            }
1090                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1091                                thread.update_token_usage_at_last_message(token_usage);
1092                                thread.cumulative_token_usage = thread.cumulative_token_usage
1093                                    + token_usage
1094                                    - current_token_usage;
1095                                current_token_usage = token_usage;
1096                            }
1097                            LanguageModelCompletionEvent::Text(chunk) => {
1098                                if let Some(last_message) = thread.messages.last_mut() {
1099                                    if last_message.role == Role::Assistant {
1100                                        last_message.push_text(&chunk);
1101                                        cx.emit(ThreadEvent::StreamedAssistantText(
1102                                            last_message.id,
1103                                            chunk,
1104                                        ));
1105                                    } else {
1106                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1107                                        // of a new Assistant response.
1108                                        //
1109                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1110                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1111                                        thread.insert_message(
1112                                            Role::Assistant,
1113                                            vec![MessageSegment::Text(chunk.to_string())],
1114                                            cx,
1115                                        );
1116                                    };
1117                                }
1118                            }
1119                            LanguageModelCompletionEvent::Thinking(chunk) => {
1120                                if let Some(last_message) = thread.messages.last_mut() {
1121                                    if last_message.role == Role::Assistant {
1122                                        last_message.push_thinking(&chunk);
1123                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1124                                            last_message.id,
1125                                            chunk,
1126                                        ));
1127                                    } else {
1128                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1129                                        // of a new Assistant response.
1130                                        //
1131                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1132                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1133                                        thread.insert_message(
1134                                            Role::Assistant,
1135                                            vec![MessageSegment::Thinking(chunk.to_string())],
1136                                            cx,
1137                                        );
1138                                    };
1139                                }
1140                            }
1141                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1142                                let last_assistant_message_id = thread
1143                                    .messages
1144                                    .iter_mut()
1145                                    .rfind(|message| message.role == Role::Assistant)
1146                                    .map(|message| message.id)
1147                                    .unwrap_or_else(|| {
1148                                        thread.insert_message(Role::Assistant, vec![], cx)
1149                                    });
1150
1151                                thread.tool_use.request_tool_use(
1152                                    last_assistant_message_id,
1153                                    tool_use,
1154                                    cx,
1155                                );
1156                            }
1157                        }
1158
1159                        thread.touch_updated_at();
1160                        cx.emit(ThreadEvent::StreamedCompletion);
1161                        cx.notify();
1162
1163                        thread.auto_capture_telemetry(cx);
1164                    })?;
1165
1166                    smol::future::yield_now().await;
1167                }
1168
1169                thread.update(cx, |thread, cx| {
1170                    thread
1171                        .pending_completions
1172                        .retain(|completion| completion.id != pending_completion_id);
1173
1174                    if thread.summary.is_none() && thread.messages.len() >= 2 {
1175                        thread.summarize(cx);
1176                    }
1177                })?;
1178
1179                anyhow::Ok(stop_reason)
1180            };
1181
1182            let result = stream_completion.await;
1183
1184            thread
1185                .update(cx, |thread, cx| {
1186                    thread.finalize_pending_checkpoint(cx);
1187                    match result.as_ref() {
1188                        Ok(stop_reason) => match stop_reason {
1189                            StopReason::ToolUse => {
1190                                let tool_uses = thread.use_pending_tools(cx);
1191                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1192                            }
1193                            StopReason::EndTurn => {}
1194                            StopReason::MaxTokens => {}
1195                        },
1196                        Err(error) => {
1197                            if error.is::<PaymentRequiredError>() {
1198                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1199                            } else if error.is::<MaxMonthlySpendReachedError>() {
1200                                cx.emit(ThreadEvent::ShowError(
1201                                    ThreadError::MaxMonthlySpendReached,
1202                                ));
1203                            } else if let Some(error) =
1204                                error.downcast_ref::<ModelRequestLimitReachedError>()
1205                            {
1206                                cx.emit(ThreadEvent::ShowError(
1207                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1208                                ));
1209                            } else if let Some(known_error) =
1210                                error.downcast_ref::<LanguageModelKnownError>()
1211                            {
1212                                match known_error {
1213                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1214                                        tokens,
1215                                    } => {
1216                                        thread.exceeded_window_error = Some(ExceededWindowError {
1217                                            model_id: model.id(),
1218                                            token_count: *tokens,
1219                                        });
1220                                        cx.notify();
1221                                    }
1222                                }
1223                            } else {
1224                                let error_message = error
1225                                    .chain()
1226                                    .map(|err| err.to_string())
1227                                    .collect::<Vec<_>>()
1228                                    .join("\n");
1229                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1230                                    header: "Error interacting with language model".into(),
1231                                    message: SharedString::from(error_message.clone()),
1232                                }));
1233                            }
1234
1235                            thread.cancel_last_completion(cx);
1236                        }
1237                    }
1238                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1239
1240                    thread.auto_capture_telemetry(cx);
1241
1242                    if let Ok(initial_usage) = initial_token_usage {
1243                        let usage = thread.cumulative_token_usage - initial_usage;
1244
1245                        telemetry::event!(
1246                            "Assistant Thread Completion",
1247                            thread_id = thread.id().to_string(),
1248                            model = model.telemetry_id(),
1249                            model_provider = model.provider_id().to_string(),
1250                            input_tokens = usage.input_tokens,
1251                            output_tokens = usage.output_tokens,
1252                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1253                            cache_read_input_tokens = usage.cache_read_input_tokens,
1254                        );
1255                    }
1256                })
1257                .ok();
1258        });
1259
1260        self.pending_completions.push(PendingCompletion {
1261            id: pending_completion_id,
1262            _task: task,
1263        });
1264    }
1265
1266    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1267        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1268            return;
1269        };
1270
1271        if !model.provider.is_authenticated(cx) {
1272            return;
1273        }
1274
1275        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1276        request.messages.push(LanguageModelRequestMessage {
1277            role: Role::User,
1278            content: vec![
1279                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1280                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1281                 If the conversation is about a specific subject, include it in the title. \
1282                 Be descriptive. DO NOT speak in the first person."
1283                    .into(),
1284            ],
1285            cache: false,
1286        });
1287
1288        self.pending_summary = cx.spawn(async move |this, cx| {
1289            async move {
1290                let stream = model.model.stream_completion_text(request, &cx);
1291                let mut messages = stream.await?;
1292
1293                let mut new_summary = String::new();
1294                while let Some(message) = messages.stream.next().await {
1295                    let text = message?;
1296                    let mut lines = text.lines();
1297                    new_summary.extend(lines.next());
1298
1299                    // Stop if the LLM generated multiple lines.
1300                    if lines.next().is_some() {
1301                        break;
1302                    }
1303                }
1304
1305                this.update(cx, |this, cx| {
1306                    if !new_summary.is_empty() {
1307                        this.summary = Some(new_summary.into());
1308                    }
1309
1310                    cx.emit(ThreadEvent::SummaryGenerated);
1311                })?;
1312
1313                anyhow::Ok(())
1314            }
1315            .log_err()
1316            .await
1317        });
1318    }
1319
1320    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1321        let last_message_id = self.messages.last().map(|message| message.id)?;
1322
1323        match &self.detailed_summary_state {
1324            DetailedSummaryState::Generating { message_id, .. }
1325            | DetailedSummaryState::Generated { message_id, .. }
1326                if *message_id == last_message_id =>
1327            {
1328                // Already up-to-date
1329                return None;
1330            }
1331            _ => {}
1332        }
1333
1334        let ConfiguredModel { model, provider } =
1335            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1336
1337        if !provider.is_authenticated(cx) {
1338            return None;
1339        }
1340
1341        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1342
1343        request.messages.push(LanguageModelRequestMessage {
1344            role: Role::User,
1345            content: vec![
1346                "Generate a detailed summary of this conversation. Include:\n\
1347                1. A brief overview of what was discussed\n\
1348                2. Key facts or information discovered\n\
1349                3. Outcomes or conclusions reached\n\
1350                4. Any action items or next steps if any\n\
1351                Format it in Markdown with headings and bullet points."
1352                    .into(),
1353            ],
1354            cache: false,
1355        });
1356
1357        let task = cx.spawn(async move |thread, cx| {
1358            let stream = model.stream_completion_text(request, &cx);
1359            let Some(mut messages) = stream.await.log_err() else {
1360                thread
1361                    .update(cx, |this, _cx| {
1362                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1363                    })
1364                    .log_err();
1365
1366                return;
1367            };
1368
1369            let mut new_detailed_summary = String::new();
1370
1371            while let Some(chunk) = messages.stream.next().await {
1372                if let Some(chunk) = chunk.log_err() {
1373                    new_detailed_summary.push_str(&chunk);
1374                }
1375            }
1376
1377            thread
1378                .update(cx, |this, _cx| {
1379                    this.detailed_summary_state = DetailedSummaryState::Generated {
1380                        text: new_detailed_summary.into(),
1381                        message_id: last_message_id,
1382                    };
1383                })
1384                .log_err();
1385        });
1386
1387        self.detailed_summary_state = DetailedSummaryState::Generating {
1388            message_id: last_message_id,
1389        };
1390
1391        Some(task)
1392    }
1393
1394    pub fn is_generating_detailed_summary(&self) -> bool {
1395        matches!(
1396            self.detailed_summary_state,
1397            DetailedSummaryState::Generating { .. }
1398        )
1399    }
1400
1401    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1402        self.auto_capture_telemetry(cx);
1403        let request = self.to_completion_request(RequestKind::Chat, cx);
1404        let messages = Arc::new(request.messages);
1405        let pending_tool_uses = self
1406            .tool_use
1407            .pending_tool_uses()
1408            .into_iter()
1409            .filter(|tool_use| tool_use.status.is_idle())
1410            .cloned()
1411            .collect::<Vec<_>>();
1412
1413        for tool_use in pending_tool_uses.iter() {
1414            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1415                if tool.needs_confirmation(&tool_use.input, cx)
1416                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1417                {
1418                    self.tool_use.confirm_tool_use(
1419                        tool_use.id.clone(),
1420                        tool_use.ui_text.clone(),
1421                        tool_use.input.clone(),
1422                        messages.clone(),
1423                        tool,
1424                    );
1425                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1426                } else {
1427                    self.run_tool(
1428                        tool_use.id.clone(),
1429                        tool_use.ui_text.clone(),
1430                        tool_use.input.clone(),
1431                        &messages,
1432                        tool,
1433                        cx,
1434                    );
1435                }
1436            }
1437        }
1438
1439        pending_tool_uses
1440    }
1441
1442    pub fn run_tool(
1443        &mut self,
1444        tool_use_id: LanguageModelToolUseId,
1445        ui_text: impl Into<SharedString>,
1446        input: serde_json::Value,
1447        messages: &[LanguageModelRequestMessage],
1448        tool: Arc<dyn Tool>,
1449        cx: &mut Context<Thread>,
1450    ) {
1451        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1452        self.tool_use
1453            .run_pending_tool(tool_use_id, ui_text.into(), task);
1454    }
1455
1456    fn spawn_tool_use(
1457        &mut self,
1458        tool_use_id: LanguageModelToolUseId,
1459        messages: &[LanguageModelRequestMessage],
1460        input: serde_json::Value,
1461        tool: Arc<dyn Tool>,
1462        cx: &mut Context<Thread>,
1463    ) -> Task<()> {
1464        let tool_name: Arc<str> = tool.name().into();
1465
1466        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1467            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1468        } else {
1469            tool.run(
1470                input,
1471                messages,
1472                self.project.clone(),
1473                self.action_log.clone(),
1474                cx,
1475            )
1476        };
1477
1478        // Store the card separately if it exists
1479        if let Some(card) = tool_result.card.clone() {
1480            self.tool_use
1481                .insert_tool_result_card(tool_use_id.clone(), card);
1482        }
1483
1484        cx.spawn({
1485            async move |thread: WeakEntity<Thread>, cx| {
1486                let output = tool_result.output.await;
1487
1488                thread
1489                    .update(cx, |thread, cx| {
1490                        let pending_tool_use = thread.tool_use.insert_tool_output(
1491                            tool_use_id.clone(),
1492                            tool_name,
1493                            output,
1494                            cx,
1495                        );
1496                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1497                    })
1498                    .ok();
1499            }
1500        })
1501    }
1502
1503    fn tool_finished(
1504        &mut self,
1505        tool_use_id: LanguageModelToolUseId,
1506        pending_tool_use: Option<PendingToolUse>,
1507        canceled: bool,
1508        cx: &mut Context<Self>,
1509    ) {
1510        if self.all_tools_finished() {
1511            let model_registry = LanguageModelRegistry::read_global(cx);
1512            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1513                self.attach_tool_results(cx);
1514                if !canceled {
1515                    self.send_to_model(model, RequestKind::Chat, cx);
1516                }
1517            }
1518        }
1519
1520        cx.emit(ThreadEvent::ToolFinished {
1521            tool_use_id,
1522            pending_tool_use,
1523        });
1524    }
1525
1526    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1527        // Insert a user message to contain the tool results.
1528        self.insert_user_message(
1529            // TODO: Sending up a user message without any content results in the model sending back
1530            // responses that also don't have any content. We currently don't handle this case well,
1531            // so for now we provide some text to keep the model on track.
1532            "Here are the tool results.",
1533            Vec::new(),
1534            None,
1535            cx,
1536        );
1537    }
1538
1539    /// Cancels the last pending completion, if there are any pending.
1540    ///
1541    /// Returns whether a completion was canceled.
1542    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1543        let canceled = if self.pending_completions.pop().is_some() {
1544            true
1545        } else {
1546            let mut canceled = false;
1547            for pending_tool_use in self.tool_use.cancel_pending() {
1548                canceled = true;
1549                self.tool_finished(
1550                    pending_tool_use.id.clone(),
1551                    Some(pending_tool_use),
1552                    true,
1553                    cx,
1554                );
1555            }
1556            canceled
1557        };
1558        self.finalize_pending_checkpoint(cx);
1559        canceled
1560    }
1561
1562    pub fn feedback(&self) -> Option<ThreadFeedback> {
1563        self.feedback
1564    }
1565
1566    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1567        self.message_feedback.get(&message_id).copied()
1568    }
1569
1570    pub fn report_message_feedback(
1571        &mut self,
1572        message_id: MessageId,
1573        feedback: ThreadFeedback,
1574        cx: &mut Context<Self>,
1575    ) -> Task<Result<()>> {
1576        if self.message_feedback.get(&message_id) == Some(&feedback) {
1577            return Task::ready(Ok(()));
1578        }
1579
1580        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1581        let serialized_thread = self.serialize(cx);
1582        let thread_id = self.id().clone();
1583        let client = self.project.read(cx).client();
1584
1585        let enabled_tool_names: Vec<String> = self
1586            .tools()
1587            .read(cx)
1588            .enabled_tools(cx)
1589            .iter()
1590            .map(|tool| tool.name().to_string())
1591            .collect();
1592
1593        self.message_feedback.insert(message_id, feedback);
1594
1595        cx.notify();
1596
1597        let message_content = self
1598            .message(message_id)
1599            .map(|msg| msg.to_string())
1600            .unwrap_or_default();
1601
1602        cx.background_spawn(async move {
1603            let final_project_snapshot = final_project_snapshot.await;
1604            let serialized_thread = serialized_thread.await?;
1605            let thread_data =
1606                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1607
1608            let rating = match feedback {
1609                ThreadFeedback::Positive => "positive",
1610                ThreadFeedback::Negative => "negative",
1611            };
1612            telemetry::event!(
1613                "Assistant Thread Rated",
1614                rating,
1615                thread_id,
1616                enabled_tool_names,
1617                message_id = message_id.0,
1618                message_content,
1619                thread_data,
1620                final_project_snapshot
1621            );
1622            client.telemetry().flush_events();
1623
1624            Ok(())
1625        })
1626    }
1627
1628    pub fn report_feedback(
1629        &mut self,
1630        feedback: ThreadFeedback,
1631        cx: &mut Context<Self>,
1632    ) -> Task<Result<()>> {
1633        let last_assistant_message_id = self
1634            .messages
1635            .iter()
1636            .rev()
1637            .find(|msg| msg.role == Role::Assistant)
1638            .map(|msg| msg.id);
1639
1640        if let Some(message_id) = last_assistant_message_id {
1641            self.report_message_feedback(message_id, feedback, cx)
1642        } else {
1643            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1644            let serialized_thread = self.serialize(cx);
1645            let thread_id = self.id().clone();
1646            let client = self.project.read(cx).client();
1647            self.feedback = Some(feedback);
1648            cx.notify();
1649
1650            cx.background_spawn(async move {
1651                let final_project_snapshot = final_project_snapshot.await;
1652                let serialized_thread = serialized_thread.await?;
1653                let thread_data = serde_json::to_value(serialized_thread)
1654                    .unwrap_or_else(|_| serde_json::Value::Null);
1655
1656                let rating = match feedback {
1657                    ThreadFeedback::Positive => "positive",
1658                    ThreadFeedback::Negative => "negative",
1659                };
1660                telemetry::event!(
1661                    "Assistant Thread Rated",
1662                    rating,
1663                    thread_id,
1664                    thread_data,
1665                    final_project_snapshot
1666                );
1667                client.telemetry().flush_events();
1668
1669                Ok(())
1670            })
1671        }
1672    }
1673
1674    /// Create a snapshot of the current project state including git information and unsaved buffers.
1675    fn project_snapshot(
1676        project: Entity<Project>,
1677        cx: &mut Context<Self>,
1678    ) -> Task<Arc<ProjectSnapshot>> {
1679        let git_store = project.read(cx).git_store().clone();
1680        let worktree_snapshots: Vec<_> = project
1681            .read(cx)
1682            .visible_worktrees(cx)
1683            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1684            .collect();
1685
1686        cx.spawn(async move |_, cx| {
1687            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1688
1689            let mut unsaved_buffers = Vec::new();
1690            cx.update(|app_cx| {
1691                let buffer_store = project.read(app_cx).buffer_store();
1692                for buffer_handle in buffer_store.read(app_cx).buffers() {
1693                    let buffer = buffer_handle.read(app_cx);
1694                    if buffer.is_dirty() {
1695                        if let Some(file) = buffer.file() {
1696                            let path = file.path().to_string_lossy().to_string();
1697                            unsaved_buffers.push(path);
1698                        }
1699                    }
1700                }
1701            })
1702            .ok();
1703
1704            Arc::new(ProjectSnapshot {
1705                worktree_snapshots,
1706                unsaved_buffer_paths: unsaved_buffers,
1707                timestamp: Utc::now(),
1708            })
1709        })
1710    }
1711
1712    fn worktree_snapshot(
1713        worktree: Entity<project::Worktree>,
1714        git_store: Entity<GitStore>,
1715        cx: &App,
1716    ) -> Task<WorktreeSnapshot> {
1717        cx.spawn(async move |cx| {
1718            // Get worktree path and snapshot
1719            let worktree_info = cx.update(|app_cx| {
1720                let worktree = worktree.read(app_cx);
1721                let path = worktree.abs_path().to_string_lossy().to_string();
1722                let snapshot = worktree.snapshot();
1723                (path, snapshot)
1724            });
1725
1726            let Ok((worktree_path, _snapshot)) = worktree_info else {
1727                return WorktreeSnapshot {
1728                    worktree_path: String::new(),
1729                    git_state: None,
1730                };
1731            };
1732
1733            let git_state = git_store
1734                .update(cx, |git_store, cx| {
1735                    git_store
1736                        .repositories()
1737                        .values()
1738                        .find(|repo| {
1739                            repo.read(cx)
1740                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1741                                .is_some()
1742                        })
1743                        .cloned()
1744                })
1745                .ok()
1746                .flatten()
1747                .map(|repo| {
1748                    repo.update(cx, |repo, _| {
1749                        let current_branch =
1750                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1751                        repo.send_job(None, |state, _| async move {
1752                            let RepositoryState::Local { backend, .. } = state else {
1753                                return GitState {
1754                                    remote_url: None,
1755                                    head_sha: None,
1756                                    current_branch,
1757                                    diff: None,
1758                                };
1759                            };
1760
1761                            let remote_url = backend.remote_url("origin");
1762                            let head_sha = backend.head_sha();
1763                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1764
1765                            GitState {
1766                                remote_url,
1767                                head_sha,
1768                                current_branch,
1769                                diff,
1770                            }
1771                        })
1772                    })
1773                });
1774
1775            let git_state = match git_state {
1776                Some(git_state) => match git_state.ok() {
1777                    Some(git_state) => git_state.await.ok(),
1778                    None => None,
1779                },
1780                None => None,
1781            };
1782
1783            WorktreeSnapshot {
1784                worktree_path,
1785                git_state,
1786            }
1787        })
1788    }
1789
1790    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1791        let mut markdown = Vec::new();
1792
1793        if let Some(summary) = self.summary() {
1794            writeln!(markdown, "# {summary}\n")?;
1795        };
1796
1797        for message in self.messages() {
1798            writeln!(
1799                markdown,
1800                "## {role}\n",
1801                role = match message.role {
1802                    Role::User => "User",
1803                    Role::Assistant => "Assistant",
1804                    Role::System => "System",
1805                }
1806            )?;
1807
1808            if !message.context.is_empty() {
1809                writeln!(markdown, "{}", message.context)?;
1810            }
1811
1812            for segment in &message.segments {
1813                match segment {
1814                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1815                    MessageSegment::Thinking(text) => {
1816                        writeln!(markdown, "<think>{}</think>\n", text)?
1817                    }
1818                }
1819            }
1820
1821            for tool_use in self.tool_uses_for_message(message.id, cx) {
1822                writeln!(
1823                    markdown,
1824                    "**Use Tool: {} ({})**",
1825                    tool_use.name, tool_use.id
1826                )?;
1827                writeln!(markdown, "```json")?;
1828                writeln!(
1829                    markdown,
1830                    "{}",
1831                    serde_json::to_string_pretty(&tool_use.input)?
1832                )?;
1833                writeln!(markdown, "```")?;
1834            }
1835
1836            for tool_result in self.tool_results_for_message(message.id) {
1837                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1838                if tool_result.is_error {
1839                    write!(markdown, " (Error)")?;
1840                }
1841
1842                writeln!(markdown, "**\n")?;
1843                writeln!(markdown, "{}", tool_result.content)?;
1844            }
1845        }
1846
1847        Ok(String::from_utf8_lossy(&markdown).to_string())
1848    }
1849
    /// Forwards to the action log's `keep_edits_in_range` for `buffer_range` of `buffer`.
    pub fn keep_edits_in_range(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_range: Range<language::Anchor>,
        cx: &mut Context<Self>,
    ) {
        self.action_log.update(cx, |action_log, cx| {
            action_log.keep_edits_in_range(buffer, buffer_range, cx)
        });
    }
1860
    /// Forwards to the action log's `keep_all_edits`.
    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
        self.action_log
            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
    }
1865
    /// Forwards to the action log's `reject_edits_in_ranges` for `buffer_ranges` of `buffer`
    /// and returns the task it produces.
    pub fn reject_edits_in_ranges(
        &mut self,
        buffer: Entity<language::Buffer>,
        buffer_ranges: Vec<Range<language::Anchor>>,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        self.action_log.update(cx, |action_log, cx| {
            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
        })
    }
1876
    /// Returns the action log entity owned by this thread.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1880
    /// Returns the project entity this thread holds.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1884
1885    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1886        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1887            return;
1888        }
1889
1890        let now = Instant::now();
1891        if let Some(last) = self.last_auto_capture_at {
1892            if now.duration_since(last).as_secs() < 10 {
1893                return;
1894            }
1895        }
1896
1897        self.last_auto_capture_at = Some(now);
1898
1899        let thread_id = self.id().clone();
1900        let github_login = self
1901            .project
1902            .read(cx)
1903            .user_store()
1904            .read(cx)
1905            .current_user()
1906            .map(|user| user.github_login.clone());
1907        let client = self.project.read(cx).client().clone();
1908        let serialize_task = self.serialize(cx);
1909
1910        cx.background_executor()
1911            .spawn(async move {
1912                if let Ok(serialized_thread) = serialize_task.await {
1913                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1914                        telemetry::event!(
1915                            "Agent Thread Auto-Captured",
1916                            thread_id = thread_id.to_string(),
1917                            thread_data = thread_data,
1918                            auto_capture_reason = "tracked_user",
1919                            github_login = github_login
1920                        );
1921
1922                        client.telemetry().flush_events();
1923                    }
1924                }
1925            })
1926            .detach();
1927    }
1928
    /// Returns the thread's `cumulative_token_usage` value.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
1932
1933    pub fn token_usage_up_to_message(&self, message_id: MessageId, cx: &App) -> TotalTokenUsage {
1934        let Some(model) = LanguageModelRegistry::read_global(cx).default_model() else {
1935            return TotalTokenUsage::default();
1936        };
1937
1938        let max = model.model.max_token_count();
1939
1940        let index = self
1941            .messages
1942            .iter()
1943            .position(|msg| msg.id == message_id)
1944            .unwrap_or(0);
1945
1946        if index == 0 {
1947            return TotalTokenUsage { total: 0, max };
1948        }
1949
1950        let token_usage = &self
1951            .request_token_usage
1952            .get(index - 1)
1953            .cloned()
1954            .unwrap_or_default();
1955
1956        TotalTokenUsage {
1957            total: token_usage.total_tokens() as usize,
1958            max,
1959        }
1960    }
1961
1962    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1963        let model_registry = LanguageModelRegistry::read_global(cx);
1964        let Some(model) = model_registry.default_model() else {
1965            return TotalTokenUsage::default();
1966        };
1967
1968        let max = model.model.max_token_count();
1969
1970        if let Some(exceeded_error) = &self.exceeded_window_error {
1971            if model.model.id() == exceeded_error.model_id {
1972                return TotalTokenUsage {
1973                    total: exceeded_error.token_count,
1974                    max,
1975                };
1976            }
1977        }
1978
1979        let total = self
1980            .token_usage_at_last_message()
1981            .unwrap_or_default()
1982            .total_tokens() as usize;
1983
1984        TotalTokenUsage { total, max }
1985    }
1986
1987    fn token_usage_at_last_message(&self) -> Option<TokenUsage> {
1988        self.request_token_usage
1989            .get(self.messages.len().saturating_sub(1))
1990            .or_else(|| self.request_token_usage.last())
1991            .cloned()
1992    }
1993
1994    fn update_token_usage_at_last_message(&mut self, token_usage: TokenUsage) {
1995        let placeholder = self.token_usage_at_last_message().unwrap_or_default();
1996        self.request_token_usage
1997            .resize(self.messages.len(), placeholder);
1998
1999        if let Some(last) = self.request_token_usage.last_mut() {
2000            *last = token_usage;
2001        }
2002    }
2003
2004    pub fn deny_tool_use(
2005        &mut self,
2006        tool_use_id: LanguageModelToolUseId,
2007        tool_name: Arc<str>,
2008        cx: &mut Context<Self>,
2009    ) {
2010        let err = Err(anyhow::anyhow!(
2011            "Permission to run tool action denied by user"
2012        ));
2013
2014        self.tool_use
2015            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
2016        self.tool_finished(tool_use_id.clone(), None, true, cx);
2017    }
2018}
2019
/// Errors carried by [`ThreadEvent::ShowError`].
///
/// The `#[error]` attribute on each variant defines its `Display` text.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Displays as "Payment required".
    #[error("Payment required")]
    PaymentRequired,
    /// Displays as "Max monthly spend reached".
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// Displays as "Model request limit reached"; carries the associated `plan`.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A free-form error, displayed as "Message {header}: {message}".
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
2034
/// Events emitted by [`Thread`] (see the `EventEmitter` impl below).
///
/// NOTE(review): the emit sites are outside this part of the file, so the variant notes
/// below describe only the payloads; confirm the exact semantics where each is emitted.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Carries a [`ThreadError`].
    ShowError(ThreadError),
    StreamedCompletion,
    /// Carries a message id and a string; presumably a streamed chunk of assistant text.
    StreamedAssistantText(MessageId, String),
    /// Carries a message id and a string; presumably a streamed chunk of thinking text.
    StreamedAssistantThinking(MessageId, String),
    /// Carries either a stop reason or a shared error.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    SummaryGenerated,
    SummaryChanged,
    /// Carries the pending tool uses.
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// Carries the id of a tool use and, optionally, the matching pending tool use.
    ToolFinished {
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    CheckpointChanged,
    ToolConfirmationNeeded,
}

// Lets `Thread` entities emit `ThreadEvent`s.
impl EventEmitter<ThreadEvent> for Thread {}
2061
/// An entry of `Thread::pending_completions`.
struct PendingCompletion {
    /// Identifier of this completion.
    id: usize,
    /// Never read (hence the underscore). NOTE(review): presumably held so that dropping
    /// this entry drops, and thereby cancels, the completion task — confirm.
    _task: Task<()>,
}
2066
2067#[cfg(test)]
2068mod tests {
2069    use super::*;
2070    use crate::{ThreadStore, context_store::ContextStore, thread_store};
2071    use assistant_settings::AssistantSettings;
2072    use context_server::ContextServerSettings;
2073    use editor::EditorSettings;
2074    use gpui::TestAppContext;
2075    use project::{FakeFs, Project};
2076    use prompt_store::PromptBuilder;
2077    use serde_json::json;
2078    use settings::{Settings, SettingsStore};
2079    use std::sync::Arc;
2080    use theme::ThemeSettings;
2081    use util::path;
2082    use workspace::Workspace;
2083
    // A user message with one attached file stores the rendered `<context>` block in
    // `message.context`, and the completion request sends that context followed by the
    // message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // Braces of the Rust snippet are doubled because this is a `format!` template.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Index 0 is a leading message (presumably the system prompt — confirm); the user
        // message follows at index 1.
        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }
2151
    // A context item that was already attached to an earlier message must not be repeated
    // in later messages: each message carries only the context that is new to the thread.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain 4 messages: one leading message at index 0 (presumably
        // the system prompt — confirm) followed by the 3 user messages at indices 1..=3
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }
2255
    // Messages inserted without any attached context have an empty `context` string and
    // appear in the completion request as their plain text.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Index 0 is a leading message (presumably the system prompt — confirm); user
        // messages start at index 1.
        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }
2321
    // After a buffer that was attached as context is edited, the next completion request
    // must end with a user message listing the changed file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at the empty range `1..1` (presumably an offset into the buffer,
            // i.e. right after its first character — not the end of line 1). The exact
            // position does not matter; the test only needs the buffer to change.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }
2400
    // Installs a test `SettingsStore` as a global and then runs the init/registration
    // calls these tests rely on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            // The settings store is installed first; NOTE(review): the calls below
            // presumably require it to exist already — keep this order.
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }
2415
2416    // Helper to create a test project with test files
2417    async fn create_test_project(
2418        cx: &mut TestAppContext,
2419        files: serde_json::Value,
2420    ) -> Entity<Project> {
2421        let fs = FakeFs::new(cx.executor());
2422        fs.insert_tree(path!("/test"), files).await;
2423        Project::test(fs, [path!("/test").as_ref()], cx).await
2424    }
2425
2426    async fn setup_test_environment(
2427        cx: &mut TestAppContext,
2428        project: Entity<Project>,
2429    ) -> (
2430        Entity<Workspace>,
2431        Entity<ThreadStore>,
2432        Entity<Thread>,
2433        Entity<ContextStore>,
2434    ) {
2435        let (workspace, cx) =
2436            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));
2437
2438        let thread_store = cx
2439            .update(|_, cx| {
2440                ThreadStore::load(
2441                    project.clone(),
2442                    cx.new(|_| ToolWorkingSet::default()),
2443                    Arc::new(PromptBuilder::new(None).unwrap()),
2444                    cx,
2445                )
2446            })
2447            .await;
2448
2449        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
2450        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));
2451
2452        (workspace, thread_store, thread, context_store)
2453    }
2454
2455    async fn add_file_to_context(
2456        project: &Entity<Project>,
2457        context_store: &Entity<ContextStore>,
2458        path: &str,
2459        cx: &mut TestAppContext,
2460    ) -> Result<Entity<language::Buffer>> {
2461        let buffer_path = project
2462            .read_with(cx, |project, cx| project.find_project_path(path, cx))
2463            .unwrap();
2464
2465        let buffer = project
2466            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
2467            .await
2468            .unwrap();
2469
2470        context_store
2471            .update(cx, |store, cx| {
2472                store.add_file_from_buffer(buffer.clone(), cx)
2473            })
2474            .await?;
2475
2476        Ok(buffer)
2477    }
2478}