thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5
   6use agent_rules::load_worktree_rules_file;
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use fs::Fs;
  14use futures::future::Shared;
  15use futures::{FutureExt, StreamExt as _};
  16use git::repository::DiffType;
  17use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  18use language_model::{
  19    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry,
  20    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
  21    LanguageModelToolResult, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  22    PaymentRequiredError, Role, StopReason, TokenUsage,
  23};
  24use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  25use project::{Project, Worktree};
  26use prompt_store::{AssistantSystemPromptContext, PromptBuilder, WorktreeInfoForSystemPrompt};
  27use schemars::JsonSchema;
  28use serde::{Deserialize, Serialize};
  29use settings::Settings;
  30use util::{ResultExt as _, TryFutureExt as _, post_inc};
  31use uuid::Uuid;
  32
  33use crate::context::{AssistantContext, ContextId, format_context_as_string};
  34use crate::thread_store::{
  35    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  36    SerializedToolUse,
  37};
  38use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  39
  40#[derive(Debug, Clone, Copy)]
  41pub enum RequestKind {
  42    Chat,
  43    /// Used when summarizing a thread.
  44    Summarize,
  45}
  46
  47#[derive(
  48    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  49)]
  50pub struct ThreadId(Arc<str>);
  51
  52impl ThreadId {
  53    pub fn new() -> Self {
  54        Self(Uuid::new_v4().to_string().into())
  55    }
  56}
  57
  58impl std::fmt::Display for ThreadId {
  59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  60        write!(f, "{}", self.0)
  61    }
  62}
  63
  64impl From<&str> for ThreadId {
  65    fn from(value: &str) -> Self {
  66        Self(value.into())
  67    }
  68}
  69
  70#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  71pub struct MessageId(pub(crate) usize);
  72
  73impl MessageId {
  74    fn post_inc(&mut self) -> Self {
  75        Self(post_inc(&mut self.0))
  76    }
  77}
  78
  79/// A message in a [`Thread`].
  80#[derive(Debug, Clone)]
  81pub struct Message {
  82    pub id: MessageId,
  83    pub role: Role,
  84    pub segments: Vec<MessageSegment>,
  85    pub context: String,
  86}
  87
  88impl Message {
  89    /// Returns whether the message contains any meaningful text that should be displayed
  90    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  91    pub fn should_display_content(&self) -> bool {
  92        self.segments.iter().all(|segment| segment.should_display())
  93    }
  94
  95    pub fn push_thinking(&mut self, text: &str) {
  96        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  97            segment.push_str(text);
  98        } else {
  99            self.segments
 100                .push(MessageSegment::Thinking(text.to_string()));
 101        }
 102    }
 103
 104    pub fn push_text(&mut self, text: &str) {
 105        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 106            segment.push_str(text);
 107        } else {
 108            self.segments.push(MessageSegment::Text(text.to_string()));
 109        }
 110    }
 111
 112    pub fn to_string(&self) -> String {
 113        let mut result = String::new();
 114
 115        if !self.context.is_empty() {
 116            result.push_str(&self.context);
 117        }
 118
 119        for segment in &self.segments {
 120            match segment {
 121                MessageSegment::Text(text) => result.push_str(text),
 122                MessageSegment::Thinking(text) => {
 123                    result.push_str("<think>");
 124                    result.push_str(text);
 125                    result.push_str("</think>");
 126                }
 127            }
 128        }
 129
 130        result
 131    }
 132}
 133
 134#[derive(Debug, Clone, PartialEq, Eq)]
 135pub enum MessageSegment {
 136    Text(String),
 137    Thinking(String),
 138}
 139
 140impl MessageSegment {
 141    pub fn text_mut(&mut self) -> &mut String {
 142        match self {
 143            Self::Text(text) => text,
 144            Self::Thinking(text) => text,
 145        }
 146    }
 147
 148    pub fn should_display(&self) -> bool {
 149        // We add USING_TOOL_MARKER when making a request that includes tool uses
 150        // without non-whitespace text around them, and this can cause the model
 151        // to mimic the pattern, so we consider those segments not displayable.
 152        match self {
 153            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 154            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 155        }
 156    }
 157}
 158
 159#[derive(Debug, Clone, Serialize, Deserialize)]
 160pub struct ProjectSnapshot {
 161    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 162    pub unsaved_buffer_paths: Vec<String>,
 163    pub timestamp: DateTime<Utc>,
 164}
 165
 166#[derive(Debug, Clone, Serialize, Deserialize)]
 167pub struct WorktreeSnapshot {
 168    pub worktree_path: String,
 169    pub git_state: Option<GitState>,
 170}
 171
 172#[derive(Debug, Clone, Serialize, Deserialize)]
 173pub struct GitState {
 174    pub remote_url: Option<String>,
 175    pub head_sha: Option<String>,
 176    pub current_branch: Option<String>,
 177    pub diff: Option<String>,
 178}
 179
 180#[derive(Clone)]
 181pub struct ThreadCheckpoint {
 182    message_id: MessageId,
 183    git_checkpoint: GitStoreCheckpoint,
 184}
 185
 186#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 187pub enum ThreadFeedback {
 188    Positive,
 189    Negative,
 190}
 191
 192pub enum LastRestoreCheckpoint {
 193    Pending {
 194        message_id: MessageId,
 195    },
 196    Error {
 197        message_id: MessageId,
 198        error: String,
 199    },
 200}
 201
 202impl LastRestoreCheckpoint {
 203    pub fn message_id(&self) -> MessageId {
 204        match self {
 205            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 206            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 207        }
 208    }
 209}
 210
 211#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 212pub enum DetailedSummaryState {
 213    #[default]
 214    NotGenerated,
 215    Generating {
 216        message_id: MessageId,
 217    },
 218    Generated {
 219        text: SharedString,
 220        message_id: MessageId,
 221    },
 222}
 223
 224#[derive(Default)]
 225pub struct TotalTokenUsage {
 226    pub total: usize,
 227    pub max: usize,
 228    pub ratio: TokenUsageRatio,
 229}
 230
 231#[derive(Default, PartialEq, Eq)]
 232pub enum TokenUsageRatio {
 233    #[default]
 234    Normal,
 235    Warning,
 236    Exceeded,
 237}
 238
 239/// A thread of conversation with the LLM.
 240pub struct Thread {
 241    id: ThreadId,
 242    updated_at: DateTime<Utc>,
 243    summary: Option<SharedString>,
 244    pending_summary: Task<Option<()>>,
 245    detailed_summary_state: DetailedSummaryState,
 246    messages: Vec<Message>,
 247    next_message_id: MessageId,
 248    context: BTreeMap<ContextId, AssistantContext>,
 249    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 250    system_prompt_context: Option<AssistantSystemPromptContext>,
 251    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 252    completion_count: usize,
 253    pending_completions: Vec<PendingCompletion>,
 254    project: Entity<Project>,
 255    prompt_builder: Arc<PromptBuilder>,
 256    tools: Arc<ToolWorkingSet>,
 257    tool_use: ToolUseState,
 258    action_log: Entity<ActionLog>,
 259    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 260    pending_checkpoint: Option<ThreadCheckpoint>,
 261    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 262    cumulative_token_usage: TokenUsage,
 263    feedback: Option<ThreadFeedback>,
 264    message_feedback: HashMap<MessageId, ThreadFeedback>,
 265}
 266
 267impl Thread {
 268    pub fn new(
 269        project: Entity<Project>,
 270        tools: Arc<ToolWorkingSet>,
 271        prompt_builder: Arc<PromptBuilder>,
 272        cx: &mut Context<Self>,
 273    ) -> Self {
 274        Self {
 275            id: ThreadId::new(),
 276            updated_at: Utc::now(),
 277            summary: None,
 278            pending_summary: Task::ready(None),
 279            detailed_summary_state: DetailedSummaryState::NotGenerated,
 280            messages: Vec::new(),
 281            next_message_id: MessageId(0),
 282            context: BTreeMap::default(),
 283            context_by_message: HashMap::default(),
 284            system_prompt_context: None,
 285            checkpoints_by_message: HashMap::default(),
 286            completion_count: 0,
 287            pending_completions: Vec::new(),
 288            project: project.clone(),
 289            prompt_builder,
 290            tools: tools.clone(),
 291            last_restore_checkpoint: None,
 292            pending_checkpoint: None,
 293            tool_use: ToolUseState::new(tools.clone()),
 294            action_log: cx.new(|_| ActionLog::new(project.clone())),
 295            initial_project_snapshot: {
 296                let project_snapshot = Self::project_snapshot(project, cx);
 297                cx.foreground_executor()
 298                    .spawn(async move { Some(project_snapshot.await) })
 299                    .shared()
 300            },
 301            cumulative_token_usage: TokenUsage::default(),
 302            feedback: None,
 303            message_feedback: HashMap::default(),
 304        }
 305    }
 306
 307    pub fn deserialize(
 308        id: ThreadId,
 309        serialized: SerializedThread,
 310        project: Entity<Project>,
 311        tools: Arc<ToolWorkingSet>,
 312        prompt_builder: Arc<PromptBuilder>,
 313        cx: &mut Context<Self>,
 314    ) -> Self {
 315        let next_message_id = MessageId(
 316            serialized
 317                .messages
 318                .last()
 319                .map(|message| message.id.0 + 1)
 320                .unwrap_or(0),
 321        );
 322        let tool_use =
 323            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 324
 325        Self {
 326            id,
 327            updated_at: serialized.updated_at,
 328            summary: Some(serialized.summary),
 329            pending_summary: Task::ready(None),
 330            detailed_summary_state: serialized.detailed_summary_state,
 331            messages: serialized
 332                .messages
 333                .into_iter()
 334                .map(|message| Message {
 335                    id: message.id,
 336                    role: message.role,
 337                    segments: message
 338                        .segments
 339                        .into_iter()
 340                        .map(|segment| match segment {
 341                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 342                            SerializedMessageSegment::Thinking { text } => {
 343                                MessageSegment::Thinking(text)
 344                            }
 345                        })
 346                        .collect(),
 347                    context: message.context,
 348                })
 349                .collect(),
 350            next_message_id,
 351            context: BTreeMap::default(),
 352            context_by_message: HashMap::default(),
 353            system_prompt_context: None,
 354            checkpoints_by_message: HashMap::default(),
 355            completion_count: 0,
 356            pending_completions: Vec::new(),
 357            last_restore_checkpoint: None,
 358            pending_checkpoint: None,
 359            project: project.clone(),
 360            prompt_builder,
 361            tools,
 362            tool_use,
 363            action_log: cx.new(|_| ActionLog::new(project)),
 364            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 365            cumulative_token_usage: serialized.cumulative_token_usage,
 366            feedback: None,
 367            message_feedback: HashMap::default(),
 368        }
 369    }
 370
 371    pub fn id(&self) -> &ThreadId {
 372        &self.id
 373    }
 374
 375    pub fn is_empty(&self) -> bool {
 376        self.messages.is_empty()
 377    }
 378
 379    pub fn updated_at(&self) -> DateTime<Utc> {
 380        self.updated_at
 381    }
 382
 383    pub fn touch_updated_at(&mut self) {
 384        self.updated_at = Utc::now();
 385    }
 386
 387    pub fn summary(&self) -> Option<SharedString> {
 388        self.summary.clone()
 389    }
 390
 391    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 392
 393    pub fn summary_or_default(&self) -> SharedString {
 394        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 395    }
 396
 397    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 398        let Some(current_summary) = &self.summary else {
 399            // Don't allow setting summary until generated
 400            return;
 401        };
 402
 403        let mut new_summary = new_summary.into();
 404
 405        if new_summary.is_empty() {
 406            new_summary = Self::DEFAULT_SUMMARY;
 407        }
 408
 409        if current_summary != &new_summary {
 410            self.summary = Some(new_summary);
 411            cx.emit(ThreadEvent::SummaryChanged);
 412        }
 413    }
 414
 415    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 416        self.latest_detailed_summary()
 417            .unwrap_or_else(|| self.text().into())
 418    }
 419
 420    fn latest_detailed_summary(&self) -> Option<SharedString> {
 421        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 422            Some(text.clone())
 423        } else {
 424            None
 425        }
 426    }
 427
 428    pub fn message(&self, id: MessageId) -> Option<&Message> {
 429        self.messages.iter().find(|message| message.id == id)
 430    }
 431
 432    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 433        self.messages.iter()
 434    }
 435
 436    pub fn is_generating(&self) -> bool {
 437        !self.pending_completions.is_empty() || !self.all_tools_finished()
 438    }
 439
 440    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 441        &self.tools
 442    }
 443
 444    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 445        self.tool_use
 446            .pending_tool_uses()
 447            .into_iter()
 448            .find(|tool_use| &tool_use.id == id)
 449    }
 450
 451    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 452        self.tool_use
 453            .pending_tool_uses()
 454            .into_iter()
 455            .filter(|tool_use| tool_use.status.needs_confirmation())
 456    }
 457
 458    pub fn has_pending_tool_uses(&self) -> bool {
 459        !self.tool_use.pending_tool_uses().is_empty()
 460    }
 461
 462    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 463        self.checkpoints_by_message.get(&id).cloned()
 464    }
 465
 466    pub fn restore_checkpoint(
 467        &mut self,
 468        checkpoint: ThreadCheckpoint,
 469        cx: &mut Context<Self>,
 470    ) -> Task<Result<()>> {
 471        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 472            message_id: checkpoint.message_id,
 473        });
 474        cx.emit(ThreadEvent::CheckpointChanged);
 475        cx.notify();
 476
 477        let git_store = self.project().read(cx).git_store().clone();
 478        let restore = git_store.update(cx, |git_store, cx| {
 479            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 480        });
 481
 482        cx.spawn(async move |this, cx| {
 483            let result = restore.await;
 484            this.update(cx, |this, cx| {
 485                if let Err(err) = result.as_ref() {
 486                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 487                        message_id: checkpoint.message_id,
 488                        error: err.to_string(),
 489                    });
 490                } else {
 491                    this.truncate(checkpoint.message_id, cx);
 492                    this.last_restore_checkpoint = None;
 493                }
 494                this.pending_checkpoint = None;
 495                cx.emit(ThreadEvent::CheckpointChanged);
 496                cx.notify();
 497            })?;
 498            result
 499        })
 500    }
 501
 502    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 503        let pending_checkpoint = if self.is_generating() {
 504            return;
 505        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 506            checkpoint
 507        } else {
 508            return;
 509        };
 510
 511        let git_store = self.project.read(cx).git_store().clone();
 512        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
 513        cx.spawn(async move |this, cx| match final_checkpoint.await {
 514            Ok(final_checkpoint) => {
 515                let equal = git_store
 516                    .update(cx, |store, cx| {
 517                        store.compare_checkpoints(
 518                            pending_checkpoint.git_checkpoint.clone(),
 519                            final_checkpoint.clone(),
 520                            cx,
 521                        )
 522                    })?
 523                    .await
 524                    .unwrap_or(false);
 525
 526                if equal {
 527                    git_store
 528                        .update(cx, |store, cx| {
 529                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 530                        })?
 531                        .detach();
 532                } else {
 533                    this.update(cx, |this, cx| {
 534                        this.insert_checkpoint(pending_checkpoint, cx)
 535                    })?;
 536                }
 537
 538                git_store
 539                    .update(cx, |store, cx| {
 540                        store.delete_checkpoint(final_checkpoint, cx)
 541                    })?
 542                    .detach();
 543
 544                Ok(())
 545            }
 546            Err(_) => this.update(cx, |this, cx| {
 547                this.insert_checkpoint(pending_checkpoint, cx)
 548            }),
 549        })
 550        .detach();
 551    }
 552
 553    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 554        self.checkpoints_by_message
 555            .insert(checkpoint.message_id, checkpoint);
 556        cx.emit(ThreadEvent::CheckpointChanged);
 557        cx.notify();
 558    }
 559
 560    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 561        self.last_restore_checkpoint.as_ref()
 562    }
 563
 564    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 565        let Some(message_ix) = self
 566            .messages
 567            .iter()
 568            .rposition(|message| message.id == message_id)
 569        else {
 570            return;
 571        };
 572        for deleted_message in self.messages.drain(message_ix..) {
 573            self.context_by_message.remove(&deleted_message.id);
 574            self.checkpoints_by_message.remove(&deleted_message.id);
 575        }
 576        cx.notify();
 577    }
 578
 579    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 580        self.context_by_message
 581            .get(&id)
 582            .into_iter()
 583            .flat_map(|context| {
 584                context
 585                    .iter()
 586                    .filter_map(|context_id| self.context.get(&context_id))
 587            })
 588    }
 589
 590    /// Returns whether all of the tool uses have finished running.
 591    pub fn all_tools_finished(&self) -> bool {
 592        // If the only pending tool uses left are the ones with errors, then
 593        // that means that we've finished running all of the pending tools.
 594        self.tool_use
 595            .pending_tool_uses()
 596            .iter()
 597            .all(|tool_use| tool_use.status.is_error())
 598    }
 599
 600    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 601        self.tool_use.tool_uses_for_message(id, cx)
 602    }
 603
 604    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 605        self.tool_use.tool_results_for_message(id)
 606    }
 607
 608    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 609        self.tool_use.tool_result(id)
 610    }
 611
 612    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 613        self.tool_use.message_has_tool_results(message_id)
 614    }
 615
 616    pub fn insert_user_message(
 617        &mut self,
 618        text: impl Into<String>,
 619        context: Vec<AssistantContext>,
 620        git_checkpoint: Option<GitStoreCheckpoint>,
 621        cx: &mut Context<Self>,
 622    ) -> MessageId {
 623        let text = text.into();
 624
 625        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 626
 627        // Filter out contexts that have already been included in previous messages
 628        let new_context: Vec<_> = context
 629            .into_iter()
 630            .filter(|ctx| !self.context.contains_key(&ctx.id()))
 631            .collect();
 632
 633        if !new_context.is_empty() {
 634            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 635                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 636                    message.context = context_string;
 637                }
 638            }
 639
 640            self.action_log.update(cx, |log, cx| {
 641                // Track all buffers added as context
 642                for ctx in &new_context {
 643                    match ctx {
 644                        AssistantContext::File(file_ctx) => {
 645                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 646                        }
 647                        AssistantContext::Directory(dir_ctx) => {
 648                            for context_buffer in &dir_ctx.context_buffers {
 649                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 650                            }
 651                        }
 652                        AssistantContext::Symbol(symbol_ctx) => {
 653                            log.buffer_added_as_context(
 654                                symbol_ctx.context_symbol.buffer.clone(),
 655                                cx,
 656                            );
 657                        }
 658                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 659                    }
 660                }
 661            });
 662        }
 663
 664        let context_ids = new_context
 665            .iter()
 666            .map(|context| context.id())
 667            .collect::<Vec<_>>();
 668        self.context.extend(
 669            new_context
 670                .into_iter()
 671                .map(|context| (context.id(), context)),
 672        );
 673        self.context_by_message.insert(message_id, context_ids);
 674
 675        if let Some(git_checkpoint) = git_checkpoint {
 676            self.pending_checkpoint = Some(ThreadCheckpoint {
 677                message_id,
 678                git_checkpoint,
 679            });
 680        }
 681
 682        self.auto_capture_telemetry(cx);
 683
 684        message_id
 685    }
 686
 687    pub fn insert_message(
 688        &mut self,
 689        role: Role,
 690        segments: Vec<MessageSegment>,
 691        cx: &mut Context<Self>,
 692    ) -> MessageId {
 693        let id = self.next_message_id.post_inc();
 694        self.messages.push(Message {
 695            id,
 696            role,
 697            segments,
 698            context: String::new(),
 699        });
 700        self.touch_updated_at();
 701        cx.emit(ThreadEvent::MessageAdded(id));
 702        id
 703    }
 704
 705    pub fn edit_message(
 706        &mut self,
 707        id: MessageId,
 708        new_role: Role,
 709        new_segments: Vec<MessageSegment>,
 710        cx: &mut Context<Self>,
 711    ) -> bool {
 712        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 713            return false;
 714        };
 715        message.role = new_role;
 716        message.segments = new_segments;
 717        self.touch_updated_at();
 718        cx.emit(ThreadEvent::MessageEdited(id));
 719        true
 720    }
 721
 722    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 723        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 724            return false;
 725        };
 726        self.messages.remove(index);
 727        self.context_by_message.remove(&id);
 728        self.touch_updated_at();
 729        cx.emit(ThreadEvent::MessageDeleted(id));
 730        true
 731    }
 732
 733    /// Returns the representation of this [`Thread`] in a textual form.
 734    ///
 735    /// This is the representation we use when attaching a thread as context to another thread.
 736    pub fn text(&self) -> String {
 737        let mut text = String::new();
 738
 739        for message in &self.messages {
 740            text.push_str(match message.role {
 741                language_model::Role::User => "User:",
 742                language_model::Role::Assistant => "Assistant:",
 743                language_model::Role::System => "System:",
 744            });
 745            text.push('\n');
 746
 747            for segment in &message.segments {
 748                match segment {
 749                    MessageSegment::Text(content) => text.push_str(content),
 750                    MessageSegment::Thinking(content) => {
 751                        text.push_str(&format!("<think>{}</think>", content))
 752                    }
 753                }
 754            }
 755            text.push('\n');
 756        }
 757
 758        text
 759    }
 760
 761    /// Serializes this thread into a format for storage or telemetry.
 762    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 763        let initial_project_snapshot = self.initial_project_snapshot.clone();
 764        cx.spawn(async move |this, cx| {
 765            let initial_project_snapshot = initial_project_snapshot.await;
 766            this.read_with(cx, |this, cx| SerializedThread {
 767                version: SerializedThread::VERSION.to_string(),
 768                summary: this.summary_or_default(),
 769                updated_at: this.updated_at(),
 770                messages: this
 771                    .messages()
 772                    .map(|message| SerializedMessage {
 773                        id: message.id,
 774                        role: message.role,
 775                        segments: message
 776                            .segments
 777                            .iter()
 778                            .map(|segment| match segment {
 779                                MessageSegment::Text(text) => {
 780                                    SerializedMessageSegment::Text { text: text.clone() }
 781                                }
 782                                MessageSegment::Thinking(text) => {
 783                                    SerializedMessageSegment::Thinking { text: text.clone() }
 784                                }
 785                            })
 786                            .collect(),
 787                        tool_uses: this
 788                            .tool_uses_for_message(message.id, cx)
 789                            .into_iter()
 790                            .map(|tool_use| SerializedToolUse {
 791                                id: tool_use.id,
 792                                name: tool_use.name,
 793                                input: tool_use.input,
 794                            })
 795                            .collect(),
 796                        tool_results: this
 797                            .tool_results_for_message(message.id)
 798                            .into_iter()
 799                            .map(|tool_result| SerializedToolResult {
 800                                tool_use_id: tool_result.tool_use_id.clone(),
 801                                is_error: tool_result.is_error,
 802                                content: tool_result.content.clone(),
 803                            })
 804                            .collect(),
 805                        context: message.context.clone(),
 806                    })
 807                    .collect(),
 808                initial_project_snapshot,
 809                cumulative_token_usage: this.cumulative_token_usage.clone(),
 810                detailed_summary_state: this.detailed_summary_state.clone(),
 811            })
 812        })
 813    }
 814
 815    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
 816        self.system_prompt_context = Some(context);
 817    }
 818
 819    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
 820        &self.system_prompt_context
 821    }
 822
 823    pub fn load_system_prompt_context(
 824        &self,
 825        cx: &App,
 826    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
 827        let project = self.project.read(cx);
 828        let tasks = project
 829            .visible_worktrees(cx)
 830            .map(|worktree| {
 831                Self::load_worktree_info_for_system_prompt(
 832                    project.fs().clone(),
 833                    worktree.read(cx),
 834                    cx,
 835                )
 836            })
 837            .collect::<Vec<_>>();
 838
 839        cx.spawn(async |_cx| {
 840            let results = futures::future::join_all(tasks).await;
 841            let mut first_err = None;
 842            let worktrees = results
 843                .into_iter()
 844                .map(|(worktree, err)| {
 845                    if first_err.is_none() && err.is_some() {
 846                        first_err = err;
 847                    }
 848                    worktree
 849                })
 850                .collect::<Vec<_>>();
 851            (AssistantSystemPromptContext::new(worktrees), first_err)
 852        })
 853    }
 854
 855    fn load_worktree_info_for_system_prompt(
 856        fs: Arc<dyn Fs>,
 857        worktree: &Worktree,
 858        cx: &App,
 859    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
 860        let root_name = worktree.root_name().into();
 861        let abs_path = worktree.abs_path();
 862
 863        let rules_task = load_worktree_rules_file(fs, worktree, cx);
 864        let Some(rules_task) = rules_task else {
 865            return Task::ready((
 866                WorktreeInfoForSystemPrompt {
 867                    root_name,
 868                    abs_path,
 869                    rules_file: None,
 870                },
 871                None,
 872            ));
 873        };
 874
 875        cx.spawn(async move |_| {
 876            let (rules_file, rules_file_error) = match rules_task.await {
 877                Ok(rules_file) => (Some(rules_file), None),
 878                Err(err) => (
 879                    None,
 880                    Some(ThreadError::Message {
 881                        header: "Error loading rules file".into(),
 882                        message: format!("{err}").into(),
 883                    }),
 884                ),
 885            };
 886            let worktree_info = WorktreeInfoForSystemPrompt {
 887                root_name,
 888                abs_path,
 889                rules_file,
 890            };
 891            (worktree_info, rules_file_error)
 892        })
 893    }
 894
 895    pub fn send_to_model(
 896        &mut self,
 897        model: Arc<dyn LanguageModel>,
 898        request_kind: RequestKind,
 899        cx: &mut Context<Self>,
 900    ) {
 901        let mut request = self.to_completion_request(request_kind, cx);
 902        if model.supports_tools() {
 903            request.tools = {
 904                let mut tools = Vec::new();
 905                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 906                    LanguageModelRequestTool {
 907                        name: tool.name(),
 908                        description: tool.description(),
 909                        input_schema: tool.input_schema(model.tool_input_format()),
 910                    }
 911                }));
 912
 913                tools
 914            };
 915        }
 916
 917        self.stream_completion(request, model, cx);
 918    }
 919
 920    pub fn used_tools_since_last_user_message(&self) -> bool {
 921        for message in self.messages.iter().rev() {
 922            if self.tool_use.message_has_tool_results(message.id) {
 923                return true;
 924            } else if message.role == Role::User {
 925                return false;
 926            }
 927        }
 928
 929        false
 930    }
 931
 932    pub fn to_completion_request(
 933        &self,
 934        request_kind: RequestKind,
 935        cx: &App,
 936    ) -> LanguageModelRequest {
 937        let mut request = LanguageModelRequest {
 938            messages: vec![],
 939            tools: Vec::new(),
 940            stop: Vec::new(),
 941            temperature: None,
 942        };
 943
 944        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
 945            if let Some(system_prompt) = self
 946                .prompt_builder
 947                .generate_assistant_system_prompt(system_prompt_context)
 948                .context("failed to generate assistant system prompt")
 949                .log_err()
 950            {
 951                request.messages.push(LanguageModelRequestMessage {
 952                    role: Role::System,
 953                    content: vec![MessageContent::Text(system_prompt)],
 954                    cache: true,
 955                });
 956            }
 957        } else {
 958            log::error!("system_prompt_context not set.")
 959        }
 960
 961        for message in &self.messages {
 962            let mut request_message = LanguageModelRequestMessage {
 963                role: message.role,
 964                content: Vec::new(),
 965                cache: false,
 966            };
 967
 968            match request_kind {
 969                RequestKind::Chat => {
 970                    self.tool_use
 971                        .attach_tool_results(message.id, &mut request_message);
 972                }
 973                RequestKind::Summarize => {
 974                    // We don't care about tool use during summarization.
 975                    if self.tool_use.message_has_tool_results(message.id) {
 976                        continue;
 977                    }
 978                }
 979            }
 980
 981            if !message.segments.is_empty() {
 982                request_message
 983                    .content
 984                    .push(MessageContent::Text(message.to_string()));
 985            }
 986
 987            match request_kind {
 988                RequestKind::Chat => {
 989                    self.tool_use
 990                        .attach_tool_uses(message.id, &mut request_message);
 991                }
 992                RequestKind::Summarize => {
 993                    // We don't care about tool use during summarization.
 994                }
 995            };
 996
 997            request.messages.push(request_message);
 998        }
 999
1000        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
1001        if let Some(last) = request.messages.last_mut() {
1002            last.cache = true;
1003        }
1004
1005        self.attached_tracked_files_state(&mut request.messages, cx);
1006
1007        request
1008    }
1009
1010    fn attached_tracked_files_state(
1011        &self,
1012        messages: &mut Vec<LanguageModelRequestMessage>,
1013        cx: &App,
1014    ) {
1015        const STALE_FILES_HEADER: &str = "These files changed since last read:";
1016
1017        let mut stale_message = String::new();
1018
1019        let action_log = self.action_log.read(cx);
1020
1021        for stale_file in action_log.stale_buffers(cx) {
1022            let Some(file) = stale_file.read(cx).file() else {
1023                continue;
1024            };
1025
1026            if stale_message.is_empty() {
1027                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
1028            }
1029
1030            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1031        }
1032
1033        let mut content = Vec::with_capacity(2);
1034
1035        if !stale_message.is_empty() {
1036            content.push(stale_message.into());
1037        }
1038
1039        if action_log.has_edited_files_since_project_diagnostics_check() {
1040            content.push(
1041                "\n\nWhen you're done making changes, make sure to check project diagnostics \
1042                and fix all errors AND warnings you introduced! \
1043                DO NOT mention you're going to do this until you're done."
1044                    .into(),
1045            );
1046        }
1047
1048        if !content.is_empty() {
1049            let context_message = LanguageModelRequestMessage {
1050                role: Role::User,
1051                content,
1052                cache: false,
1053            };
1054
1055            messages.push(context_message);
1056        }
1057    }
1058
1059    pub fn stream_completion(
1060        &mut self,
1061        request: LanguageModelRequest,
1062        model: Arc<dyn LanguageModel>,
1063        cx: &mut Context<Self>,
1064    ) {
1065        let pending_completion_id = post_inc(&mut self.completion_count);
1066
1067        let task = cx.spawn(async move |thread, cx| {
1068            let stream = model.stream_completion(request, &cx);
1069            let initial_token_usage =
1070                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
1071            let stream_completion = async {
1072                let mut events = stream.await?;
1073                let mut stop_reason = StopReason::EndTurn;
1074                let mut current_token_usage = TokenUsage::default();
1075
1076                while let Some(event) = events.next().await {
1077                    let event = event?;
1078
1079                    thread.update(cx, |thread, cx| {
1080                        match event {
1081                            LanguageModelCompletionEvent::StartMessage { .. } => {
1082                                thread.insert_message(
1083                                    Role::Assistant,
1084                                    vec![MessageSegment::Text(String::new())],
1085                                    cx,
1086                                );
1087                            }
1088                            LanguageModelCompletionEvent::Stop(reason) => {
1089                                stop_reason = reason;
1090                            }
1091                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1092                                thread.cumulative_token_usage =
1093                                    thread.cumulative_token_usage.clone() + token_usage.clone()
1094                                        - current_token_usage.clone();
1095                                current_token_usage = token_usage;
1096                            }
1097                            LanguageModelCompletionEvent::Text(chunk) => {
1098                                if let Some(last_message) = thread.messages.last_mut() {
1099                                    if last_message.role == Role::Assistant {
1100                                        last_message.push_text(&chunk);
1101                                        cx.emit(ThreadEvent::StreamedAssistantText(
1102                                            last_message.id,
1103                                            chunk,
1104                                        ));
1105                                    } else {
1106                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1107                                        // of a new Assistant response.
1108                                        //
1109                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1110                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1111                                        thread.insert_message(
1112                                            Role::Assistant,
1113                                            vec![MessageSegment::Text(chunk.to_string())],
1114                                            cx,
1115                                        );
1116                                    };
1117                                }
1118                            }
1119                            LanguageModelCompletionEvent::Thinking(chunk) => {
1120                                if let Some(last_message) = thread.messages.last_mut() {
1121                                    if last_message.role == Role::Assistant {
1122                                        last_message.push_thinking(&chunk);
1123                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1124                                            last_message.id,
1125                                            chunk,
1126                                        ));
1127                                    } else {
1128                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1129                                        // of a new Assistant response.
1130                                        //
1131                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1132                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1133                                        thread.insert_message(
1134                                            Role::Assistant,
1135                                            vec![MessageSegment::Thinking(chunk.to_string())],
1136                                            cx,
1137                                        );
1138                                    };
1139                                }
1140                            }
1141                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1142                                let last_assistant_message_id = thread
1143                                    .messages
1144                                    .iter_mut()
1145                                    .rfind(|message| message.role == Role::Assistant)
1146                                    .map(|message| message.id)
1147                                    .unwrap_or_else(|| {
1148                                        thread.insert_message(Role::Assistant, vec![], cx)
1149                                    });
1150
1151                                thread.tool_use.request_tool_use(
1152                                    last_assistant_message_id,
1153                                    tool_use,
1154                                    cx,
1155                                );
1156                            }
1157                        }
1158
1159                        thread.touch_updated_at();
1160                        cx.emit(ThreadEvent::StreamedCompletion);
1161                        cx.notify();
1162
1163                        thread.auto_capture_telemetry(cx);
1164                    })?;
1165
1166                    smol::future::yield_now().await;
1167                }
1168
1169                thread.update(cx, |thread, cx| {
1170                    thread
1171                        .pending_completions
1172                        .retain(|completion| completion.id != pending_completion_id);
1173
1174                    if thread.summary.is_none() && thread.messages.len() >= 2 {
1175                        thread.summarize(cx);
1176                    }
1177                })?;
1178
1179                anyhow::Ok(stop_reason)
1180            };
1181
1182            let result = stream_completion.await;
1183
1184            thread
1185                .update(cx, |thread, cx| {
1186                    thread.finalize_pending_checkpoint(cx);
1187                    match result.as_ref() {
1188                        Ok(stop_reason) => match stop_reason {
1189                            StopReason::ToolUse => {
1190                                let tool_uses = thread.use_pending_tools(cx);
1191                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1192                            }
1193                            StopReason::EndTurn => {}
1194                            StopReason::MaxTokens => {}
1195                        },
1196                        Err(error) => {
1197                            if error.is::<PaymentRequiredError>() {
1198                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1199                            } else if error.is::<MaxMonthlySpendReachedError>() {
1200                                cx.emit(ThreadEvent::ShowError(
1201                                    ThreadError::MaxMonthlySpendReached,
1202                                ));
1203                            } else {
1204                                let error_message = error
1205                                    .chain()
1206                                    .map(|err| err.to_string())
1207                                    .collect::<Vec<_>>()
1208                                    .join("\n");
1209                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1210                                    header: "Error interacting with language model".into(),
1211                                    message: SharedString::from(error_message.clone()),
1212                                }));
1213                            }
1214
1215                            thread.cancel_last_completion(cx);
1216                        }
1217                    }
1218                    cx.emit(ThreadEvent::DoneStreaming);
1219
1220                    thread.auto_capture_telemetry(cx);
1221
1222                    if let Ok(initial_usage) = initial_token_usage {
1223                        let usage = thread.cumulative_token_usage.clone() - initial_usage;
1224
1225                        telemetry::event!(
1226                            "Assistant Thread Completion",
1227                            thread_id = thread.id().to_string(),
1228                            model = model.telemetry_id(),
1229                            model_provider = model.provider_id().to_string(),
1230                            input_tokens = usage.input_tokens,
1231                            output_tokens = usage.output_tokens,
1232                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1233                            cache_read_input_tokens = usage.cache_read_input_tokens,
1234                        );
1235                    }
1236                })
1237                .ok();
1238        });
1239
1240        self.pending_completions.push(PendingCompletion {
1241            id: pending_completion_id,
1242            _task: task,
1243        });
1244    }
1245
1246    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1247        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1248            return;
1249        };
1250
1251        if !model.provider.is_authenticated(cx) {
1252            return;
1253        }
1254
1255        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1256        request.messages.push(LanguageModelRequestMessage {
1257            role: Role::User,
1258            content: vec![
1259                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1260                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1261                 If the conversation is about a specific subject, include it in the title. \
1262                 Be descriptive. DO NOT speak in the first person."
1263                    .into(),
1264            ],
1265            cache: false,
1266        });
1267
1268        self.pending_summary = cx.spawn(async move |this, cx| {
1269            async move {
1270                let stream = model.model.stream_completion_text(request, &cx);
1271                let mut messages = stream.await?;
1272
1273                let mut new_summary = String::new();
1274                while let Some(message) = messages.stream.next().await {
1275                    let text = message?;
1276                    let mut lines = text.lines();
1277                    new_summary.extend(lines.next());
1278
1279                    // Stop if the LLM generated multiple lines.
1280                    if lines.next().is_some() {
1281                        break;
1282                    }
1283                }
1284
1285                this.update(cx, |this, cx| {
1286                    if !new_summary.is_empty() {
1287                        this.summary = Some(new_summary.into());
1288                    }
1289
1290                    cx.emit(ThreadEvent::SummaryGenerated);
1291                })?;
1292
1293                anyhow::Ok(())
1294            }
1295            .log_err()
1296            .await
1297        });
1298    }
1299
1300    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1301        let last_message_id = self.messages.last().map(|message| message.id)?;
1302
1303        match &self.detailed_summary_state {
1304            DetailedSummaryState::Generating { message_id, .. }
1305            | DetailedSummaryState::Generated { message_id, .. }
1306                if *message_id == last_message_id =>
1307            {
1308                // Already up-to-date
1309                return None;
1310            }
1311            _ => {}
1312        }
1313
1314        let ConfiguredModel { model, provider } =
1315            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1316
1317        if !provider.is_authenticated(cx) {
1318            return None;
1319        }
1320
1321        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1322
1323        request.messages.push(LanguageModelRequestMessage {
1324            role: Role::User,
1325            content: vec![
1326                "Generate a detailed summary of this conversation. Include:\n\
1327                1. A brief overview of what was discussed\n\
1328                2. Key facts or information discovered\n\
1329                3. Outcomes or conclusions reached\n\
1330                4. Any action items or next steps if any\n\
1331                Format it in Markdown with headings and bullet points."
1332                    .into(),
1333            ],
1334            cache: false,
1335        });
1336
1337        let task = cx.spawn(async move |thread, cx| {
1338            let stream = model.stream_completion_text(request, &cx);
1339            let Some(mut messages) = stream.await.log_err() else {
1340                thread
1341                    .update(cx, |this, _cx| {
1342                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1343                    })
1344                    .log_err();
1345
1346                return;
1347            };
1348
1349            let mut new_detailed_summary = String::new();
1350
1351            while let Some(chunk) = messages.stream.next().await {
1352                if let Some(chunk) = chunk.log_err() {
1353                    new_detailed_summary.push_str(&chunk);
1354                }
1355            }
1356
1357            thread
1358                .update(cx, |this, _cx| {
1359                    this.detailed_summary_state = DetailedSummaryState::Generated {
1360                        text: new_detailed_summary.into(),
1361                        message_id: last_message_id,
1362                    };
1363                })
1364                .log_err();
1365        });
1366
1367        self.detailed_summary_state = DetailedSummaryState::Generating {
1368            message_id: last_message_id,
1369        };
1370
1371        Some(task)
1372    }
1373
1374    pub fn is_generating_detailed_summary(&self) -> bool {
1375        matches!(
1376            self.detailed_summary_state,
1377            DetailedSummaryState::Generating { .. }
1378        )
1379    }
1380
1381    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1382        self.auto_capture_telemetry(cx);
1383        let request = self.to_completion_request(RequestKind::Chat, cx);
1384        let messages = Arc::new(request.messages);
1385        let pending_tool_uses = self
1386            .tool_use
1387            .pending_tool_uses()
1388            .into_iter()
1389            .filter(|tool_use| tool_use.status.is_idle())
1390            .cloned()
1391            .collect::<Vec<_>>();
1392
1393        for tool_use in pending_tool_uses.iter() {
1394            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1395                if tool.needs_confirmation(&tool_use.input, cx)
1396                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1397                {
1398                    self.tool_use.confirm_tool_use(
1399                        tool_use.id.clone(),
1400                        tool_use.ui_text.clone(),
1401                        tool_use.input.clone(),
1402                        messages.clone(),
1403                        tool,
1404                    );
1405                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1406                } else {
1407                    self.run_tool(
1408                        tool_use.id.clone(),
1409                        tool_use.ui_text.clone(),
1410                        tool_use.input.clone(),
1411                        &messages,
1412                        tool,
1413                        cx,
1414                    );
1415                }
1416            }
1417        }
1418
1419        pending_tool_uses
1420    }
1421
1422    pub fn run_tool(
1423        &mut self,
1424        tool_use_id: LanguageModelToolUseId,
1425        ui_text: impl Into<SharedString>,
1426        input: serde_json::Value,
1427        messages: &[LanguageModelRequestMessage],
1428        tool: Arc<dyn Tool>,
1429        cx: &mut Context<Thread>,
1430    ) {
1431        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1432        self.tool_use
1433            .run_pending_tool(tool_use_id, ui_text.into(), task);
1434    }
1435
1436    fn spawn_tool_use(
1437        &mut self,
1438        tool_use_id: LanguageModelToolUseId,
1439        messages: &[LanguageModelRequestMessage],
1440        input: serde_json::Value,
1441        tool: Arc<dyn Tool>,
1442        cx: &mut Context<Thread>,
1443    ) -> Task<()> {
1444        let tool_name: Arc<str> = tool.name().into();
1445
1446        let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1447            Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1448        } else {
1449            tool.run(
1450                input,
1451                messages,
1452                self.project.clone(),
1453                self.action_log.clone(),
1454                cx,
1455            )
1456        };
1457
1458        cx.spawn({
1459            async move |thread: WeakEntity<Thread>, cx| {
1460                let output = run_tool.await;
1461
1462                thread
1463                    .update(cx, |thread, cx| {
1464                        let pending_tool_use = thread.tool_use.insert_tool_output(
1465                            tool_use_id.clone(),
1466                            tool_name,
1467                            output,
1468                            cx,
1469                        );
1470                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1471                    })
1472                    .ok();
1473            }
1474        })
1475    }
1476
1477    fn tool_finished(
1478        &mut self,
1479        tool_use_id: LanguageModelToolUseId,
1480        pending_tool_use: Option<PendingToolUse>,
1481        canceled: bool,
1482        cx: &mut Context<Self>,
1483    ) {
1484        if self.all_tools_finished() {
1485            let model_registry = LanguageModelRegistry::read_global(cx);
1486            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1487                self.attach_tool_results(cx);
1488                if !canceled {
1489                    self.send_to_model(model, RequestKind::Chat, cx);
1490                }
1491            }
1492        }
1493
1494        cx.emit(ThreadEvent::ToolFinished {
1495            tool_use_id,
1496            pending_tool_use,
1497        });
1498    }
1499
1500    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1501        // Insert a user message to contain the tool results.
1502        self.insert_user_message(
1503            // TODO: Sending up a user message without any content results in the model sending back
1504            // responses that also don't have any content. We currently don't handle this case well,
1505            // so for now we provide some text to keep the model on track.
1506            "Here are the tool results.",
1507            Vec::new(),
1508            None,
1509            cx,
1510        );
1511    }
1512
1513    /// Cancels the last pending completion, if there are any pending.
1514    ///
1515    /// Returns whether a completion was canceled.
1516    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1517        let canceled = if self.pending_completions.pop().is_some() {
1518            true
1519        } else {
1520            let mut canceled = false;
1521            for pending_tool_use in self.tool_use.cancel_pending() {
1522                canceled = true;
1523                self.tool_finished(
1524                    pending_tool_use.id.clone(),
1525                    Some(pending_tool_use),
1526                    true,
1527                    cx,
1528                );
1529            }
1530            canceled
1531        };
1532        self.finalize_pending_checkpoint(cx);
1533        canceled
1534    }
1535
1536    pub fn feedback(&self) -> Option<ThreadFeedback> {
1537        self.feedback
1538    }
1539
1540    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1541        self.message_feedback.get(&message_id).copied()
1542    }
1543
1544    pub fn report_message_feedback(
1545        &mut self,
1546        message_id: MessageId,
1547        feedback: ThreadFeedback,
1548        cx: &mut Context<Self>,
1549    ) -> Task<Result<()>> {
1550        if self.message_feedback.get(&message_id) == Some(&feedback) {
1551            return Task::ready(Ok(()));
1552        }
1553
1554        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1555        let serialized_thread = self.serialize(cx);
1556        let thread_id = self.id().clone();
1557        let client = self.project.read(cx).client();
1558
1559        let enabled_tool_names: Vec<String> = self
1560            .tools()
1561            .enabled_tools(cx)
1562            .iter()
1563            .map(|tool| tool.name().to_string())
1564            .collect();
1565
1566        self.message_feedback.insert(message_id, feedback);
1567
1568        cx.notify();
1569
1570        let message_content = self
1571            .message(message_id)
1572            .map(|msg| msg.to_string())
1573            .unwrap_or_default();
1574
1575        cx.background_spawn(async move {
1576            let final_project_snapshot = final_project_snapshot.await;
1577            let serialized_thread = serialized_thread.await?;
1578            let thread_data =
1579                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1580
1581            let rating = match feedback {
1582                ThreadFeedback::Positive => "positive",
1583                ThreadFeedback::Negative => "negative",
1584            };
1585            telemetry::event!(
1586                "Assistant Thread Rated",
1587                rating,
1588                thread_id,
1589                enabled_tool_names,
1590                message_id = message_id.0,
1591                message_content,
1592                thread_data,
1593                final_project_snapshot
1594            );
1595            client.telemetry().flush_events();
1596
1597            Ok(())
1598        })
1599    }
1600
1601    pub fn report_feedback(
1602        &mut self,
1603        feedback: ThreadFeedback,
1604        cx: &mut Context<Self>,
1605    ) -> Task<Result<()>> {
1606        let last_assistant_message_id = self
1607            .messages
1608            .iter()
1609            .rev()
1610            .find(|msg| msg.role == Role::Assistant)
1611            .map(|msg| msg.id);
1612
1613        if let Some(message_id) = last_assistant_message_id {
1614            self.report_message_feedback(message_id, feedback, cx)
1615        } else {
1616            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1617            let serialized_thread = self.serialize(cx);
1618            let thread_id = self.id().clone();
1619            let client = self.project.read(cx).client();
1620            self.feedback = Some(feedback);
1621            cx.notify();
1622
1623            cx.background_spawn(async move {
1624                let final_project_snapshot = final_project_snapshot.await;
1625                let serialized_thread = serialized_thread.await?;
1626                let thread_data = serde_json::to_value(serialized_thread)
1627                    .unwrap_or_else(|_| serde_json::Value::Null);
1628
1629                let rating = match feedback {
1630                    ThreadFeedback::Positive => "positive",
1631                    ThreadFeedback::Negative => "negative",
1632                };
1633                telemetry::event!(
1634                    "Assistant Thread Rated",
1635                    rating,
1636                    thread_id,
1637                    thread_data,
1638                    final_project_snapshot
1639                );
1640                client.telemetry().flush_events();
1641
1642                Ok(())
1643            })
1644        }
1645    }
1646
1647    /// Create a snapshot of the current project state including git information and unsaved buffers.
1648    fn project_snapshot(
1649        project: Entity<Project>,
1650        cx: &mut Context<Self>,
1651    ) -> Task<Arc<ProjectSnapshot>> {
1652        let git_store = project.read(cx).git_store().clone();
1653        let worktree_snapshots: Vec<_> = project
1654            .read(cx)
1655            .visible_worktrees(cx)
1656            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1657            .collect();
1658
1659        cx.spawn(async move |_, cx| {
1660            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1661
1662            let mut unsaved_buffers = Vec::new();
1663            cx.update(|app_cx| {
1664                let buffer_store = project.read(app_cx).buffer_store();
1665                for buffer_handle in buffer_store.read(app_cx).buffers() {
1666                    let buffer = buffer_handle.read(app_cx);
1667                    if buffer.is_dirty() {
1668                        if let Some(file) = buffer.file() {
1669                            let path = file.path().to_string_lossy().to_string();
1670                            unsaved_buffers.push(path);
1671                        }
1672                    }
1673                }
1674            })
1675            .ok();
1676
1677            Arc::new(ProjectSnapshot {
1678                worktree_snapshots,
1679                unsaved_buffer_paths: unsaved_buffers,
1680                timestamp: Utc::now(),
1681            })
1682        })
1683    }
1684
1685    fn worktree_snapshot(
1686        worktree: Entity<project::Worktree>,
1687        git_store: Entity<GitStore>,
1688        cx: &App,
1689    ) -> Task<WorktreeSnapshot> {
1690        cx.spawn(async move |cx| {
1691            // Get worktree path and snapshot
1692            let worktree_info = cx.update(|app_cx| {
1693                let worktree = worktree.read(app_cx);
1694                let path = worktree.abs_path().to_string_lossy().to_string();
1695                let snapshot = worktree.snapshot();
1696                (path, snapshot)
1697            });
1698
1699            let Ok((worktree_path, _snapshot)) = worktree_info else {
1700                return WorktreeSnapshot {
1701                    worktree_path: String::new(),
1702                    git_state: None,
1703                };
1704            };
1705
1706            let git_state = git_store
1707                .update(cx, |git_store, cx| {
1708                    git_store
1709                        .repositories()
1710                        .values()
1711                        .find(|repo| {
1712                            repo.read(cx)
1713                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1714                                .is_some()
1715                        })
1716                        .cloned()
1717                })
1718                .ok()
1719                .flatten()
1720                .map(|repo| {
1721                    repo.update(cx, |repo, _| {
1722                        let current_branch =
1723                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1724                        repo.send_job(None, |state, _| async move {
1725                            let RepositoryState::Local { backend, .. } = state else {
1726                                return GitState {
1727                                    remote_url: None,
1728                                    head_sha: None,
1729                                    current_branch,
1730                                    diff: None,
1731                                };
1732                            };
1733
1734                            let remote_url = backend.remote_url("origin");
1735                            let head_sha = backend.head_sha();
1736                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1737
1738                            GitState {
1739                                remote_url,
1740                                head_sha,
1741                                current_branch,
1742                                diff,
1743                            }
1744                        })
1745                    })
1746                });
1747
1748            let git_state = match git_state {
1749                Some(git_state) => match git_state.ok() {
1750                    Some(git_state) => git_state.await.ok(),
1751                    None => None,
1752                },
1753                None => None,
1754            };
1755
1756            WorktreeSnapshot {
1757                worktree_path,
1758                git_state,
1759            }
1760        })
1761    }
1762
1763    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1764        let mut markdown = Vec::new();
1765
1766        if let Some(summary) = self.summary() {
1767            writeln!(markdown, "# {summary}\n")?;
1768        };
1769
1770        for message in self.messages() {
1771            writeln!(
1772                markdown,
1773                "## {role}\n",
1774                role = match message.role {
1775                    Role::User => "User",
1776                    Role::Assistant => "Assistant",
1777                    Role::System => "System",
1778                }
1779            )?;
1780
1781            if !message.context.is_empty() {
1782                writeln!(markdown, "{}", message.context)?;
1783            }
1784
1785            for segment in &message.segments {
1786                match segment {
1787                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1788                    MessageSegment::Thinking(text) => {
1789                        writeln!(markdown, "<think>{}</think>\n", text)?
1790                    }
1791                }
1792            }
1793
1794            for tool_use in self.tool_uses_for_message(message.id, cx) {
1795                writeln!(
1796                    markdown,
1797                    "**Use Tool: {} ({})**",
1798                    tool_use.name, tool_use.id
1799                )?;
1800                writeln!(markdown, "```json")?;
1801                writeln!(
1802                    markdown,
1803                    "{}",
1804                    serde_json::to_string_pretty(&tool_use.input)?
1805                )?;
1806                writeln!(markdown, "```")?;
1807            }
1808
1809            for tool_result in self.tool_results_for_message(message.id) {
1810                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1811                if tool_result.is_error {
1812                    write!(markdown, " (Error)")?;
1813                }
1814
1815                writeln!(markdown, "**\n")?;
1816                writeln!(markdown, "{}", tool_result.content)?;
1817            }
1818        }
1819
1820        Ok(String::from_utf8_lossy(&markdown).to_string())
1821    }
1822
1823    pub fn keep_edits_in_range(
1824        &mut self,
1825        buffer: Entity<language::Buffer>,
1826        buffer_range: Range<language::Anchor>,
1827        cx: &mut Context<Self>,
1828    ) {
1829        self.action_log.update(cx, |action_log, cx| {
1830            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1831        });
1832    }
1833
1834    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1835        self.action_log
1836            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1837    }
1838
1839    pub fn reject_edits_in_range(
1840        &mut self,
1841        buffer: Entity<language::Buffer>,
1842        buffer_range: Range<language::Anchor>,
1843        cx: &mut Context<Self>,
1844    ) -> Task<Result<()>> {
1845        self.action_log.update(cx, |action_log, cx| {
1846            action_log.reject_edits_in_range(buffer, buffer_range, cx)
1847        })
1848    }
1849
1850    pub fn action_log(&self) -> &Entity<ActionLog> {
1851        &self.action_log
1852    }
1853
1854    pub fn project(&self) -> &Entity<Project> {
1855        &self.project
1856    }
1857
1858    pub fn cumulative_token_usage(&self) -> TokenUsage {
1859        self.cumulative_token_usage.clone()
1860    }
1861
1862    pub fn auto_capture_telemetry(&self, cx: &mut Context<Self>) {
1863        static mut LAST_CAPTURE: Option<std::time::Instant> = None;
1864        let now = std::time::Instant::now();
1865        let should_check = unsafe {
1866            if let Some(last) = LAST_CAPTURE {
1867                if now.duration_since(last).as_secs() < 10 {
1868                    return;
1869                }
1870            }
1871            LAST_CAPTURE = Some(now);
1872            true
1873        };
1874
1875        if !should_check {
1876            return;
1877        }
1878
1879        let feature_flag_enabled = cx.has_flag::<feature_flags::ThreadAutoCapture>();
1880
1881        if cfg!(debug_assertions) {
1882            if !feature_flag_enabled {
1883                return;
1884            }
1885        }
1886
1887        let thread_id = self.id().clone();
1888
1889        let github_handle = self
1890            .project
1891            .read(cx)
1892            .user_store()
1893            .read(cx)
1894            .current_user()
1895            .map(|user| user.github_login.clone());
1896
1897        let client = self.project.read(cx).client().clone();
1898
1899        let serialized_thread = self.serialize(cx);
1900
1901        cx.foreground_executor()
1902            .spawn(async move {
1903                if let Ok(serialized_thread) = serialized_thread.await {
1904                    let thread_data = serde_json::to_value(serialized_thread)
1905                        .unwrap_or_else(|_| serde_json::Value::Null);
1906
1907                    telemetry::event!(
1908                        "Agent Thread AutoCaptured",
1909                        thread_id = thread_id.to_string(),
1910                        thread_data = thread_data,
1911                        auto_capture_reason = "tracked_user",
1912                        github_handle = github_handle
1913                    );
1914
1915                    client.telemetry().flush_events();
1916                }
1917            })
1918            .detach();
1919    }
1920
1921    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1922        let model_registry = LanguageModelRegistry::read_global(cx);
1923        let Some(model) = model_registry.default_model() else {
1924            return TotalTokenUsage::default();
1925        };
1926
1927        let max = model.model.max_token_count();
1928
1929        #[cfg(debug_assertions)]
1930        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1931            .unwrap_or("0.8".to_string())
1932            .parse()
1933            .unwrap();
1934        #[cfg(not(debug_assertions))]
1935        let warning_threshold: f32 = 0.8;
1936
1937        let total = self.cumulative_token_usage.total_tokens() as usize;
1938
1939        let ratio = if total >= max {
1940            TokenUsageRatio::Exceeded
1941        } else if total as f32 / max as f32 >= warning_threshold {
1942            TokenUsageRatio::Warning
1943        } else {
1944            TokenUsageRatio::Normal
1945        };
1946
1947        TotalTokenUsage { total, max, ratio }
1948    }
1949
1950    pub fn deny_tool_use(
1951        &mut self,
1952        tool_use_id: LanguageModelToolUseId,
1953        tool_name: Arc<str>,
1954        cx: &mut Context<Self>,
1955    ) {
1956        let err = Err(anyhow::anyhow!(
1957            "Permission to run tool action denied by user"
1958        ));
1959
1960        self.tool_use
1961            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1962        self.tool_finished(tool_use_id.clone(), None, true, cx);
1963    }
1964}
1965
1966#[derive(Debug, Clone)]
1967pub enum ThreadError {
1968    PaymentRequired,
1969    MaxMonthlySpendReached,
1970    Message {
1971        header: SharedString,
1972        message: SharedString,
1973    },
1974}
1975
1976#[derive(Debug, Clone)]
1977pub enum ThreadEvent {
1978    ShowError(ThreadError),
1979    StreamedCompletion,
1980    StreamedAssistantText(MessageId, String),
1981    StreamedAssistantThinking(MessageId, String),
1982    DoneStreaming,
1983    MessageAdded(MessageId),
1984    MessageEdited(MessageId),
1985    MessageDeleted(MessageId),
1986    SummaryGenerated,
1987    SummaryChanged,
1988    UsePendingTools {
1989        tool_uses: Vec<PendingToolUse>,
1990    },
1991    ToolFinished {
1992        #[allow(unused)]
1993        tool_use_id: LanguageModelToolUseId,
1994        /// The pending tool use that corresponds to this tool.
1995        pending_tool_use: Option<PendingToolUse>,
1996    },
1997    CheckpointChanged,
1998    ToolConfirmationNeeded,
1999}
2000
// Declares `ThreadEvent` as the event type that `Thread` entities emit through gpui.
impl EventEmitter<ThreadEvent> for Thread {}
2002
2003struct PendingCompletion {
2004    id: usize,
2005    _task: Task<()>,
2006}
2007
#[cfg(test)]
mod tests {
    // Tests for how user messages and their attached context are stored on a `Thread`, and
    // how they (plus stale-buffer notices) are rendered into completion requests.
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// A file attached as context is stored on the user message in the `<context>` format and
    /// is prepended to the message text in the completion request.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        // `{{`/`}}` are escaped braces inside `format!`; the file body itself is unchanged.
        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 1);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[0].string_contents(), expected_full_message);
    }

    /// Context already attached to an earlier message is not repeated: each message only
    /// carries the files that are new to the thread, on the message and in the request.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain all 3 messages
        assert_eq!(request.messages.len(), 3);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[0].string_contents().contains("file1.rs"));
        assert!(!request.messages[0].string_contents().contains("file2.rs"));
        assert!(!request.messages[0].string_contents().contains("file3.rs"));

        assert!(!request.messages[1].string_contents().contains("file1.rs"));
        assert!(request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(!request.messages[2].string_contents().contains("file2.rs"));
        assert!(request.messages[2].string_contents().contains("file3.rs"));
    }

    /// Messages sent without context have an empty context string and appear in the request
    /// as just their text.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 1);
        assert_eq!(
            request.messages[0].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[0].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[1].string_contents(),
            "Are there any good books?"
        );
    }

    /// Editing a buffer after it was attached as context makes the next request end with a
    /// user message listing the changed file.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at byte offset 1 (inside the first line); the exact position does not
            // matter, any edit makes the attached buffer stale.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Installs a test `SettingsStore` global, then initializes/registers the settings and
    /// globals these tests rely on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    // Helper to create a test project with test files: `files` is inserted into a fake
    // filesystem under `/test` (made platform-appropriate by `path!`), which becomes the
    // project's worktree root.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Opens a workspace window for `project` and creates a thread store, a fresh thread, and
    /// an empty context store. Returns `(workspace, thread_store, thread, context_store)`.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx.update(|_, cx| {
            ThreadStore::new(
                project.clone(),
                Arc::default(),
                Arc::new(PromptBuilder::new(None).unwrap()),
                cx,
            )
            .unwrap()
        });

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    /// Opens the buffer at `path` in `project`, adds it to `context_store`, and returns it.
    ///
    /// Panics if `path` cannot be resolved or the buffer cannot be opened; an error from adding
    /// the buffer to the context store is returned instead.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .unwrap();

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await
            .unwrap();

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}