thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5use std::time::Instant;
   6
   7use anyhow::{Context as _, Result, anyhow};
   8use assistant_settings::AssistantSettings;
   9use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
  10use chrono::{DateTime, Utc};
  11use collections::{BTreeMap, HashMap};
  12use feature_flags::{self, FeatureFlagAppExt};
  13use futures::future::Shared;
  14use futures::{FutureExt, StreamExt as _};
  15use git::repository::DiffType;
  16use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  17use language_model::{
  18    ConfiguredModel, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
  19    LanguageModelKnownError, LanguageModelRegistry, LanguageModelRequest,
  20    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  21    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent,
  22    ModelRequestLimitReachedError, PaymentRequiredError, Role, StopReason, TokenUsage,
  23};
  24use project::Project;
  25use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  26use prompt_store::PromptBuilder;
  27use proto::Plan;
  28use schemars::JsonSchema;
  29use serde::{Deserialize, Serialize};
  30use settings::Settings;
  31use thiserror::Error;
  32use util::{ResultExt as _, TryFutureExt as _, post_inc};
  33use uuid::Uuid;
  34
  35use crate::context::{AssistantContext, ContextId, format_context_as_string};
  36use crate::thread_store::{
  37    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  38    SerializedToolUse, SharedProjectContext,
  39};
  40use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  41
  42#[derive(Debug, Clone, Copy)]
  43pub enum RequestKind {
  44    Chat,
  45    /// Used when summarizing a thread.
  46    Summarize,
  47}
  48
  49#[derive(
  50    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  51)]
  52pub struct ThreadId(Arc<str>);
  53
  54impl ThreadId {
  55    pub fn new() -> Self {
  56        Self(Uuid::new_v4().to_string().into())
  57    }
  58}
  59
  60impl std::fmt::Display for ThreadId {
  61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  62        write!(f, "{}", self.0)
  63    }
  64}
  65
  66impl From<&str> for ThreadId {
  67    fn from(value: &str) -> Self {
  68        Self(value.into())
  69    }
  70}
  71
  72#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  73pub struct MessageId(pub(crate) usize);
  74
  75impl MessageId {
  76    fn post_inc(&mut self) -> Self {
  77        Self(post_inc(&mut self.0))
  78    }
  79}
  80
  81/// A message in a [`Thread`].
  82#[derive(Debug, Clone)]
  83pub struct Message {
  84    pub id: MessageId,
  85    pub role: Role,
  86    pub segments: Vec<MessageSegment>,
  87    pub context: String,
  88}
  89
  90impl Message {
  91    /// Returns whether the message contains any meaningful text that should be displayed
  92    /// The model sometimes runs tool without producing any text or just a marker ([`USING_TOOL_MARKER`])
  93    pub fn should_display_content(&self) -> bool {
  94        self.segments.iter().all(|segment| segment.should_display())
  95    }
  96
  97    pub fn push_thinking(&mut self, text: &str) {
  98        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  99            segment.push_str(text);
 100        } else {
 101            self.segments
 102                .push(MessageSegment::Thinking(text.to_string()));
 103        }
 104    }
 105
 106    pub fn push_text(&mut self, text: &str) {
 107        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 108            segment.push_str(text);
 109        } else {
 110            self.segments.push(MessageSegment::Text(text.to_string()));
 111        }
 112    }
 113
 114    pub fn to_string(&self) -> String {
 115        let mut result = String::new();
 116
 117        if !self.context.is_empty() {
 118            result.push_str(&self.context);
 119        }
 120
 121        for segment in &self.segments {
 122            match segment {
 123                MessageSegment::Text(text) => result.push_str(text),
 124                MessageSegment::Thinking(text) => {
 125                    result.push_str("<think>");
 126                    result.push_str(text);
 127                    result.push_str("</think>");
 128                }
 129            }
 130        }
 131
 132        result
 133    }
 134}
 135
 136#[derive(Debug, Clone, PartialEq, Eq)]
 137pub enum MessageSegment {
 138    Text(String),
 139    Thinking(String),
 140}
 141
 142impl MessageSegment {
 143    pub fn text_mut(&mut self) -> &mut String {
 144        match self {
 145            Self::Text(text) => text,
 146            Self::Thinking(text) => text,
 147        }
 148    }
 149
 150    pub fn should_display(&self) -> bool {
 151        // We add USING_TOOL_MARKER when making a request that includes tool uses
 152        // without non-whitespace text around them, and this can cause the model
 153        // to mimic the pattern, so we consider those segments not displayable.
 154        match self {
 155            Self::Text(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 156            Self::Thinking(text) => text.is_empty() || text.trim() == USING_TOOL_MARKER,
 157        }
 158    }
 159}
 160
/// Snapshot of the project's worktrees and unsaved buffers.
///
/// `Thread::new` starts capturing one when a thread is created, and it is stored
/// with the serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    // NOTE(review): presumably paths of buffers with unsaved edits — confirm in `project_snapshot`.
    pub unsaved_buffer_paths: Vec<String>,
    pub timestamp: DateTime<Utc>,
}

/// Per-worktree entry of a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    pub worktree_path: String,
    /// Git details for the worktree; `None` when none were recorded.
    pub git_state: Option<GitState>,
}

/// Git details recorded for a worktree; every field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    // NOTE(review): presumably a textual diff of the worktree — confirm in `project_snapshot`.
    pub diff: Option<String>,
}
 181
/// A git-store checkpoint paired with the message it was taken for.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// Message the checkpoint belongs to; a successful restore truncates the
    /// thread from (and including) this message.
    message_id: MessageId,
    /// Checkpoint handed to the project's git store for restore/compare/delete.
    git_checkpoint: GitStoreCheckpoint,
}
 187
 188#[derive(Copy, Clone, Debug, PartialEq, Eq)]
 189pub enum ThreadFeedback {
 190    Positive,
 191    Negative,
 192}
 193
 194pub enum LastRestoreCheckpoint {
 195    Pending {
 196        message_id: MessageId,
 197    },
 198    Error {
 199        message_id: MessageId,
 200        error: String,
 201    },
 202}
 203
 204impl LastRestoreCheckpoint {
 205    pub fn message_id(&self) -> MessageId {
 206        match self {
 207            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 208            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 209        }
 210    }
 211}
 212
/// Progress of a thread's detailed summary; persisted with the serialized thread.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary exists (the default state).
    #[default]
    NotGenerated,
    /// A detailed summary is being generated.
    Generating {
        // NOTE(review): presumably the last message the summary will cover — confirm.
        message_id: MessageId,
    },
    /// A detailed summary is available; its `text` is what
    /// `latest_detailed_summary_or_text` returns.
    Generated {
        text: SharedString,
        // NOTE(review): presumably the last message the summary covers — confirm.
        message_id: MessageId,
    },
}
 225
 226#[derive(Default)]
 227pub struct TotalTokenUsage {
 228    pub total: usize,
 229    pub max: usize,
 230    pub ratio: TokenUsageRatio,
 231}
 232
 233#[derive(Debug, Default, PartialEq, Eq)]
 234pub enum TokenUsageRatio {
 235    #[default]
 236    Normal,
 237    Warning,
 238    Exceeded,
 239}
 240
/// A thread of conversation with the LLM.
pub struct Thread {
    id: ThreadId,
    /// Last modification time; refreshed by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// Generated summary; `None` until one exists (`set_summary` is a no-op until then).
    summary: Option<SharedString>,
    pending_summary: Task<Option<()>>,
    detailed_summary_state: DetailedSummaryState,
    /// Messages in insertion order.
    messages: Vec<Message>,
    /// Id handed out to the next inserted message.
    next_message_id: MessageId,
    /// Every context attached by any message so far, keyed by id; used to avoid
    /// attaching the same context twice.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Ids of the contexts each message newly attached.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Project context used to build the system prompt.
    project_context: SharedProjectContext,
    /// Recorded git checkpoints, keyed by the user message they were taken for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    completion_count: usize,
    /// In-flight completions; while non-empty, `is_generating` returns `true`.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    tools: Entity<ToolWorkingSet>,
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// State of the latest checkpoint restore (pending or failed); cleared on success.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken for the latest user message, awaiting
    /// `finalize_pending_checkpoint` to decide whether to keep it.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured at creation (or loaded with a deserialized thread).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    cumulative_token_usage: TokenUsage,
    /// Details of the last context-window overflow, if any.
    exceeded_window_error: Option<ExceededWindowError>,
    feedback: Option<ThreadFeedback>,
    message_feedback: HashMap<MessageId, ThreadFeedback>,
    last_auto_capture_at: Option<Instant>,
}
 270
/// Records that a request overflowed a model's context window.
///
/// Persisted as part of the serialized thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExceededWindowError {
    /// Model used when last message exceeded context window
    model_id: LanguageModelId,
    /// Token count including last message
    token_count: usize,
}
 278
 279impl Thread {
 280    pub fn new(
 281        project: Entity<Project>,
 282        tools: Entity<ToolWorkingSet>,
 283        prompt_builder: Arc<PromptBuilder>,
 284        system_prompt: SharedProjectContext,
 285        cx: &mut Context<Self>,
 286    ) -> Self {
 287        Self {
 288            id: ThreadId::new(),
 289            updated_at: Utc::now(),
 290            summary: None,
 291            pending_summary: Task::ready(None),
 292            detailed_summary_state: DetailedSummaryState::NotGenerated,
 293            messages: Vec::new(),
 294            next_message_id: MessageId(0),
 295            context: BTreeMap::default(),
 296            context_by_message: HashMap::default(),
 297            project_context: system_prompt,
 298            checkpoints_by_message: HashMap::default(),
 299            completion_count: 0,
 300            pending_completions: Vec::new(),
 301            project: project.clone(),
 302            prompt_builder,
 303            tools: tools.clone(),
 304            last_restore_checkpoint: None,
 305            pending_checkpoint: None,
 306            tool_use: ToolUseState::new(tools.clone()),
 307            action_log: cx.new(|_| ActionLog::new(project.clone())),
 308            initial_project_snapshot: {
 309                let project_snapshot = Self::project_snapshot(project, cx);
 310                cx.foreground_executor()
 311                    .spawn(async move { Some(project_snapshot.await) })
 312                    .shared()
 313            },
 314            cumulative_token_usage: TokenUsage::default(),
 315            exceeded_window_error: None,
 316            feedback: None,
 317            message_feedback: HashMap::default(),
 318            last_auto_capture_at: None,
 319        }
 320    }
 321
 322    pub fn deserialize(
 323        id: ThreadId,
 324        serialized: SerializedThread,
 325        project: Entity<Project>,
 326        tools: Entity<ToolWorkingSet>,
 327        prompt_builder: Arc<PromptBuilder>,
 328        project_context: SharedProjectContext,
 329        cx: &mut Context<Self>,
 330    ) -> Self {
 331        let next_message_id = MessageId(
 332            serialized
 333                .messages
 334                .last()
 335                .map(|message| message.id.0 + 1)
 336                .unwrap_or(0),
 337        );
 338        let tool_use =
 339            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 340
 341        Self {
 342            id,
 343            updated_at: serialized.updated_at,
 344            summary: Some(serialized.summary),
 345            pending_summary: Task::ready(None),
 346            detailed_summary_state: serialized.detailed_summary_state,
 347            messages: serialized
 348                .messages
 349                .into_iter()
 350                .map(|message| Message {
 351                    id: message.id,
 352                    role: message.role,
 353                    segments: message
 354                        .segments
 355                        .into_iter()
 356                        .map(|segment| match segment {
 357                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 358                            SerializedMessageSegment::Thinking { text } => {
 359                                MessageSegment::Thinking(text)
 360                            }
 361                        })
 362                        .collect(),
 363                    context: message.context,
 364                })
 365                .collect(),
 366            next_message_id,
 367            context: BTreeMap::default(),
 368            context_by_message: HashMap::default(),
 369            project_context,
 370            checkpoints_by_message: HashMap::default(),
 371            completion_count: 0,
 372            pending_completions: Vec::new(),
 373            last_restore_checkpoint: None,
 374            pending_checkpoint: None,
 375            project: project.clone(),
 376            prompt_builder,
 377            tools,
 378            tool_use,
 379            action_log: cx.new(|_| ActionLog::new(project)),
 380            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 381            cumulative_token_usage: serialized.cumulative_token_usage,
 382            exceeded_window_error: None,
 383            feedback: None,
 384            message_feedback: HashMap::default(),
 385            last_auto_capture_at: None,
 386        }
 387    }
 388
 389    pub fn id(&self) -> &ThreadId {
 390        &self.id
 391    }
 392
 393    pub fn is_empty(&self) -> bool {
 394        self.messages.is_empty()
 395    }
 396
 397    pub fn updated_at(&self) -> DateTime<Utc> {
 398        self.updated_at
 399    }
 400
 401    pub fn touch_updated_at(&mut self) {
 402        self.updated_at = Utc::now();
 403    }
 404
 405    pub fn summary(&self) -> Option<SharedString> {
 406        self.summary.clone()
 407    }
 408
 409    pub fn project_context(&self) -> SharedProjectContext {
 410        self.project_context.clone()
 411    }
 412
 413    pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread");
 414
 415    pub fn summary_or_default(&self) -> SharedString {
 416        self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY)
 417    }
 418
 419    pub fn set_summary(&mut self, new_summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 420        let Some(current_summary) = &self.summary else {
 421            // Don't allow setting summary until generated
 422            return;
 423        };
 424
 425        let mut new_summary = new_summary.into();
 426
 427        if new_summary.is_empty() {
 428            new_summary = Self::DEFAULT_SUMMARY;
 429        }
 430
 431        if current_summary != &new_summary {
 432            self.summary = Some(new_summary);
 433            cx.emit(ThreadEvent::SummaryChanged);
 434        }
 435    }
 436
 437    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 438        self.latest_detailed_summary()
 439            .unwrap_or_else(|| self.text().into())
 440    }
 441
 442    fn latest_detailed_summary(&self) -> Option<SharedString> {
 443        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 444            Some(text.clone())
 445        } else {
 446            None
 447        }
 448    }
 449
 450    pub fn message(&self, id: MessageId) -> Option<&Message> {
 451        self.messages.iter().find(|message| message.id == id)
 452    }
 453
 454    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 455        self.messages.iter()
 456    }
 457
 458    pub fn is_generating(&self) -> bool {
 459        !self.pending_completions.is_empty() || !self.all_tools_finished()
 460    }
 461
 462    pub fn tools(&self) -> &Entity<ToolWorkingSet> {
 463        &self.tools
 464    }
 465
 466    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 467        self.tool_use
 468            .pending_tool_uses()
 469            .into_iter()
 470            .find(|tool_use| &tool_use.id == id)
 471    }
 472
 473    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 474        self.tool_use
 475            .pending_tool_uses()
 476            .into_iter()
 477            .filter(|tool_use| tool_use.status.needs_confirmation())
 478    }
 479
 480    pub fn has_pending_tool_uses(&self) -> bool {
 481        !self.tool_use.pending_tool_uses().is_empty()
 482    }
 483
 484    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 485        self.checkpoints_by_message.get(&id).cloned()
 486    }
 487
 488    pub fn restore_checkpoint(
 489        &mut self,
 490        checkpoint: ThreadCheckpoint,
 491        cx: &mut Context<Self>,
 492    ) -> Task<Result<()>> {
 493        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 494            message_id: checkpoint.message_id,
 495        });
 496        cx.emit(ThreadEvent::CheckpointChanged);
 497        cx.notify();
 498
 499        let git_store = self.project().read(cx).git_store().clone();
 500        let restore = git_store.update(cx, |git_store, cx| {
 501            git_store.restore_checkpoint(checkpoint.git_checkpoint.clone(), cx)
 502        });
 503
 504        cx.spawn(async move |this, cx| {
 505            let result = restore.await;
 506            this.update(cx, |this, cx| {
 507                if let Err(err) = result.as_ref() {
 508                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 509                        message_id: checkpoint.message_id,
 510                        error: err.to_string(),
 511                    });
 512                } else {
 513                    this.truncate(checkpoint.message_id, cx);
 514                    this.last_restore_checkpoint = None;
 515                }
 516                this.pending_checkpoint = None;
 517                cx.emit(ThreadEvent::CheckpointChanged);
 518                cx.notify();
 519            })?;
 520            result
 521        })
 522    }
 523
    /// Resolves the checkpoint taken for the latest user message once generation ends.
    ///
    /// Returns immediately while the thread is still generating, or when there is
    /// no pending checkpoint. Otherwise it takes the pending checkpoint and creates
    /// a new "final" checkpoint of the current git state. A detached task then:
    /// - records the pending checkpoint via `insert_checkpoint` if the final
    ///   checkpoint cannot be created;
    /// - deletes the pending checkpoint from the git store without recording it,
    ///   if the two checkpoints compare equal (git state unchanged);
    /// - otherwise records it, including when the comparison itself fails.
    ///
    /// When the final checkpoint was created, it is deleted after the comparison.
    /// An error inside the detached task (e.g. a failed entity update) ends it
    /// early and is discarded.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.update(cx, |git_store, cx| git_store.checkpoint(cx));
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison counts as "changed", so the checkpoint is kept.
                let equal = git_store
                    .update(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    // Git state is unchanged since the user message; nothing to restore to.
                    git_store
                        .update(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The final checkpoint only existed for the comparison above.
                git_store
                    .update(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            // Could not capture the current state; keep the pending checkpoint.
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 574
 575    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 576        self.checkpoints_by_message
 577            .insert(checkpoint.message_id, checkpoint);
 578        cx.emit(ThreadEvent::CheckpointChanged);
 579        cx.notify();
 580    }
 581
 582    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 583        self.last_restore_checkpoint.as_ref()
 584    }
 585
 586    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 587        let Some(message_ix) = self
 588            .messages
 589            .iter()
 590            .rposition(|message| message.id == message_id)
 591        else {
 592            return;
 593        };
 594        for deleted_message in self.messages.drain(message_ix..) {
 595            self.context_by_message.remove(&deleted_message.id);
 596            self.checkpoints_by_message.remove(&deleted_message.id);
 597        }
 598        cx.notify();
 599    }
 600
 601    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 602        self.context_by_message
 603            .get(&id)
 604            .into_iter()
 605            .flat_map(|context| {
 606                context
 607                    .iter()
 608                    .filter_map(|context_id| self.context.get(&context_id))
 609            })
 610    }
 611
 612    /// Returns whether all of the tool uses have finished running.
 613    pub fn all_tools_finished(&self) -> bool {
 614        // If the only pending tool uses left are the ones with errors, then
 615        // that means that we've finished running all of the pending tools.
 616        self.tool_use
 617            .pending_tool_uses()
 618            .iter()
 619            .all(|tool_use| tool_use.status.is_error())
 620    }
 621
 622    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 623        self.tool_use.tool_uses_for_message(id, cx)
 624    }
 625
 626    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 627        self.tool_use.tool_results_for_message(id)
 628    }
 629
 630    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 631        self.tool_use.tool_result(id)
 632    }
 633
 634    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 635        self.tool_use.message_has_tool_results(message_id)
 636    }
 637
 638    pub fn insert_user_message(
 639        &mut self,
 640        text: impl Into<String>,
 641        context: Vec<AssistantContext>,
 642        git_checkpoint: Option<GitStoreCheckpoint>,
 643        cx: &mut Context<Self>,
 644    ) -> MessageId {
 645        let text = text.into();
 646
 647        let message_id = self.insert_message(Role::User, vec![MessageSegment::Text(text)], cx);
 648
 649        // Filter out contexts that have already been included in previous messages
 650        let new_context: Vec<_> = context
 651            .into_iter()
 652            .filter(|ctx| !self.context.contains_key(&ctx.id()))
 653            .collect();
 654
 655        if !new_context.is_empty() {
 656            if let Some(context_string) = format_context_as_string(new_context.iter(), cx) {
 657                if let Some(message) = self.messages.iter_mut().find(|m| m.id == message_id) {
 658                    message.context = context_string;
 659                }
 660            }
 661
 662            self.action_log.update(cx, |log, cx| {
 663                // Track all buffers added as context
 664                for ctx in &new_context {
 665                    match ctx {
 666                        AssistantContext::File(file_ctx) => {
 667                            log.buffer_added_as_context(file_ctx.context_buffer.buffer.clone(), cx);
 668                        }
 669                        AssistantContext::Directory(dir_ctx) => {
 670                            for context_buffer in &dir_ctx.context_buffers {
 671                                log.buffer_added_as_context(context_buffer.buffer.clone(), cx);
 672                            }
 673                        }
 674                        AssistantContext::Symbol(symbol_ctx) => {
 675                            log.buffer_added_as_context(
 676                                symbol_ctx.context_symbol.buffer.clone(),
 677                                cx,
 678                            );
 679                        }
 680                        AssistantContext::FetchedUrl(_) | AssistantContext::Thread(_) => {}
 681                    }
 682                }
 683            });
 684        }
 685
 686        let context_ids = new_context
 687            .iter()
 688            .map(|context| context.id())
 689            .collect::<Vec<_>>();
 690        self.context.extend(
 691            new_context
 692                .into_iter()
 693                .map(|context| (context.id(), context)),
 694        );
 695        self.context_by_message.insert(message_id, context_ids);
 696
 697        if let Some(git_checkpoint) = git_checkpoint {
 698            self.pending_checkpoint = Some(ThreadCheckpoint {
 699                message_id,
 700                git_checkpoint,
 701            });
 702        }
 703
 704        self.auto_capture_telemetry(cx);
 705
 706        message_id
 707    }
 708
 709    pub fn insert_message(
 710        &mut self,
 711        role: Role,
 712        segments: Vec<MessageSegment>,
 713        cx: &mut Context<Self>,
 714    ) -> MessageId {
 715        let id = self.next_message_id.post_inc();
 716        self.messages.push(Message {
 717            id,
 718            role,
 719            segments,
 720            context: String::new(),
 721        });
 722        self.touch_updated_at();
 723        cx.emit(ThreadEvent::MessageAdded(id));
 724        id
 725    }
 726
 727    pub fn edit_message(
 728        &mut self,
 729        id: MessageId,
 730        new_role: Role,
 731        new_segments: Vec<MessageSegment>,
 732        cx: &mut Context<Self>,
 733    ) -> bool {
 734        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 735            return false;
 736        };
 737        message.role = new_role;
 738        message.segments = new_segments;
 739        self.touch_updated_at();
 740        cx.emit(ThreadEvent::MessageEdited(id));
 741        true
 742    }
 743
 744    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 745        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 746            return false;
 747        };
 748        self.messages.remove(index);
 749        self.context_by_message.remove(&id);
 750        self.touch_updated_at();
 751        cx.emit(ThreadEvent::MessageDeleted(id));
 752        true
 753    }
 754
 755    /// Returns the representation of this [`Thread`] in a textual form.
 756    ///
 757    /// This is the representation we use when attaching a thread as context to another thread.
 758    pub fn text(&self) -> String {
 759        let mut text = String::new();
 760
 761        for message in &self.messages {
 762            text.push_str(match message.role {
 763                language_model::Role::User => "User:",
 764                language_model::Role::Assistant => "Assistant:",
 765                language_model::Role::System => "System:",
 766            });
 767            text.push('\n');
 768
 769            for segment in &message.segments {
 770                match segment {
 771                    MessageSegment::Text(content) => text.push_str(content),
 772                    MessageSegment::Thinking(content) => {
 773                        text.push_str(&format!("<think>{}</think>", content))
 774                    }
 775                }
 776            }
 777            text.push('\n');
 778        }
 779
 780        text
 781    }
 782
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// The returned task first awaits the initial project snapshot. It then reads
    /// the thread through its weak handle to build a [`SerializedThread`] holding:
    /// - each message's segments, tool uses, tool results and context text;
    /// - the summary, or [`Self::DEFAULT_SUMMARY`] when none exists;
    /// - token usage, the detailed-summary state and the context-window error.
    // NOTE(review): presumably resolves to an error if the thread entity was released
    // before the read — confirm against `WeakEntity::read_with`.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                // An unset summary is persisted as the default title.
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking(text) => {
                                    SerializedMessageSegment::Thinking { text: text.clone() }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                        context: message.context.clone(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage,
                detailed_summary_state: this.detailed_summary_state.clone(),
                exceeded_window_error: this.exceeded_window_error.clone(),
            })
        })
    }
 837
 838    pub fn send_to_model(
 839        &mut self,
 840        model: Arc<dyn LanguageModel>,
 841        request_kind: RequestKind,
 842        cx: &mut Context<Self>,
 843    ) {
 844        let mut request = self.to_completion_request(request_kind, cx);
 845        if model.supports_tools() {
 846            request.tools = {
 847                let mut tools = Vec::new();
 848                tools.extend(
 849                    self.tools()
 850                        .read(cx)
 851                        .enabled_tools(cx)
 852                        .into_iter()
 853                        .filter_map(|tool| {
 854                            // Skip tools that cannot be supported
 855                            let input_schema = tool.input_schema(model.tool_input_format()).ok()?;
 856                            Some(LanguageModelRequestTool {
 857                                name: tool.name(),
 858                                description: tool.description(),
 859                                input_schema,
 860                            })
 861                        }),
 862                );
 863
 864                tools
 865            };
 866        }
 867
 868        self.stream_completion(request, model, cx);
 869    }
 870
 871    pub fn used_tools_since_last_user_message(&self) -> bool {
 872        for message in self.messages.iter().rev() {
 873            if self.tool_use.message_has_tool_results(message.id) {
 874                return true;
 875            } else if message.role == Role::User {
 876                return false;
 877            }
 878        }
 879
 880        false
 881    }
 882
 883    pub fn to_completion_request(
 884        &self,
 885        request_kind: RequestKind,
 886        cx: &App,
 887    ) -> LanguageModelRequest {
 888        let mut request = LanguageModelRequest {
 889            messages: vec![],
 890            tools: Vec::new(),
 891            stop: Vec::new(),
 892            temperature: None,
 893        };
 894
 895        if let Some(project_context) = self.project_context.borrow().as_ref() {
 896            if let Some(system_prompt) = self
 897                .prompt_builder
 898                .generate_assistant_system_prompt(project_context)
 899                .context("failed to generate assistant system prompt")
 900                .log_err()
 901            {
 902                request.messages.push(LanguageModelRequestMessage {
 903                    role: Role::System,
 904                    content: vec![MessageContent::Text(system_prompt)],
 905                    cache: true,
 906                });
 907            }
 908        } else {
 909            log::error!("project_context not set.")
 910        }
 911
 912        for message in &self.messages {
 913            let mut request_message = LanguageModelRequestMessage {
 914                role: message.role,
 915                content: Vec::new(),
 916                cache: false,
 917            };
 918
 919            match request_kind {
 920                RequestKind::Chat => {
 921                    self.tool_use
 922                        .attach_tool_results(message.id, &mut request_message);
 923                }
 924                RequestKind::Summarize => {
 925                    // We don't care about tool use during summarization.
 926                    if self.tool_use.message_has_tool_results(message.id) {
 927                        continue;
 928                    }
 929                }
 930            }
 931
 932            if !message.segments.is_empty() {
 933                request_message
 934                    .content
 935                    .push(MessageContent::Text(message.to_string()));
 936            }
 937
 938            match request_kind {
 939                RequestKind::Chat => {
 940                    self.tool_use
 941                        .attach_tool_uses(message.id, &mut request_message);
 942                }
 943                RequestKind::Summarize => {
 944                    // We don't care about tool use during summarization.
 945                }
 946            };
 947
 948            request.messages.push(request_message);
 949        }
 950
 951        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 952        if let Some(last) = request.messages.last_mut() {
 953            last.cache = true;
 954        }
 955
 956        self.attached_tracked_files_state(&mut request.messages, cx);
 957
 958        request
 959    }
 960
 961    fn attached_tracked_files_state(
 962        &self,
 963        messages: &mut Vec<LanguageModelRequestMessage>,
 964        cx: &App,
 965    ) {
 966        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 967
 968        let mut stale_message = String::new();
 969
 970        let action_log = self.action_log.read(cx);
 971
 972        for stale_file in action_log.stale_buffers(cx) {
 973            let Some(file) = stale_file.read(cx).file() else {
 974                continue;
 975            };
 976
 977            if stale_message.is_empty() {
 978                write!(&mut stale_message, "{}\n", STALE_FILES_HEADER).ok();
 979            }
 980
 981            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 982        }
 983
 984        let mut content = Vec::with_capacity(2);
 985
 986        if !stale_message.is_empty() {
 987            content.push(stale_message.into());
 988        }
 989
 990        if action_log.has_edited_files_since_project_diagnostics_check() {
 991            content.push(
 992                "\n\nWhen you're done making changes, make sure to check project diagnostics \
 993                and fix all errors AND warnings you introduced! \
 994                DO NOT mention you're going to do this until you're done."
 995                    .into(),
 996            );
 997        }
 998
 999        if !content.is_empty() {
1000            let context_message = LanguageModelRequestMessage {
1001                role: Role::User,
1002                content,
1003                cache: false,
1004            };
1005
1006            messages.push(context_message);
1007        }
1008    }
1009
1010    pub fn stream_completion(
1011        &mut self,
1012        request: LanguageModelRequest,
1013        model: Arc<dyn LanguageModel>,
1014        cx: &mut Context<Self>,
1015    ) {
1016        let pending_completion_id = post_inc(&mut self.completion_count);
1017
1018        let task = cx.spawn(async move |thread, cx| {
1019            let stream = model.stream_completion(request, &cx);
1020            let initial_token_usage =
1021                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage);
1022            let stream_completion = async {
1023                let mut events = stream.await?;
1024                let mut stop_reason = StopReason::EndTurn;
1025                let mut current_token_usage = TokenUsage::default();
1026
1027                while let Some(event) = events.next().await {
1028                    let event = event?;
1029
1030                    thread.update(cx, |thread, cx| {
1031                        match event {
1032                            LanguageModelCompletionEvent::StartMessage { .. } => {
1033                                thread.insert_message(
1034                                    Role::Assistant,
1035                                    vec![MessageSegment::Text(String::new())],
1036                                    cx,
1037                                );
1038                            }
1039                            LanguageModelCompletionEvent::Stop(reason) => {
1040                                stop_reason = reason;
1041                            }
1042                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1043                                thread.cumulative_token_usage = thread.cumulative_token_usage
1044                                    + token_usage
1045                                    - current_token_usage;
1046                                current_token_usage = token_usage;
1047                            }
1048                            LanguageModelCompletionEvent::Text(chunk) => {
1049                                if let Some(last_message) = thread.messages.last_mut() {
1050                                    if last_message.role == Role::Assistant {
1051                                        last_message.push_text(&chunk);
1052                                        cx.emit(ThreadEvent::StreamedAssistantText(
1053                                            last_message.id,
1054                                            chunk,
1055                                        ));
1056                                    } else {
1057                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1058                                        // of a new Assistant response.
1059                                        //
1060                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1061                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1062                                        thread.insert_message(
1063                                            Role::Assistant,
1064                                            vec![MessageSegment::Text(chunk.to_string())],
1065                                            cx,
1066                                        );
1067                                    };
1068                                }
1069                            }
1070                            LanguageModelCompletionEvent::Thinking(chunk) => {
1071                                if let Some(last_message) = thread.messages.last_mut() {
1072                                    if last_message.role == Role::Assistant {
1073                                        last_message.push_thinking(&chunk);
1074                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1075                                            last_message.id,
1076                                            chunk,
1077                                        ));
1078                                    } else {
1079                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
1080                                        // of a new Assistant response.
1081                                        //
1082                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1083                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1084                                        thread.insert_message(
1085                                            Role::Assistant,
1086                                            vec![MessageSegment::Thinking(chunk.to_string())],
1087                                            cx,
1088                                        );
1089                                    };
1090                                }
1091                            }
1092                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1093                                let last_assistant_message_id = thread
1094                                    .messages
1095                                    .iter_mut()
1096                                    .rfind(|message| message.role == Role::Assistant)
1097                                    .map(|message| message.id)
1098                                    .unwrap_or_else(|| {
1099                                        thread.insert_message(Role::Assistant, vec![], cx)
1100                                    });
1101
1102                                thread.tool_use.request_tool_use(
1103                                    last_assistant_message_id,
1104                                    tool_use,
1105                                    cx,
1106                                );
1107                            }
1108                        }
1109
1110                        thread.touch_updated_at();
1111                        cx.emit(ThreadEvent::StreamedCompletion);
1112                        cx.notify();
1113
1114                        thread.auto_capture_telemetry(cx);
1115                    })?;
1116
1117                    smol::future::yield_now().await;
1118                }
1119
1120                thread.update(cx, |thread, cx| {
1121                    thread
1122                        .pending_completions
1123                        .retain(|completion| completion.id != pending_completion_id);
1124
1125                    if thread.summary.is_none() && thread.messages.len() >= 2 {
1126                        thread.summarize(cx);
1127                    }
1128                })?;
1129
1130                anyhow::Ok(stop_reason)
1131            };
1132
1133            let result = stream_completion.await;
1134
1135            thread
1136                .update(cx, |thread, cx| {
1137                    thread.finalize_pending_checkpoint(cx);
1138                    match result.as_ref() {
1139                        Ok(stop_reason) => match stop_reason {
1140                            StopReason::ToolUse => {
1141                                let tool_uses = thread.use_pending_tools(cx);
1142                                cx.emit(ThreadEvent::UsePendingTools { tool_uses });
1143                            }
1144                            StopReason::EndTurn => {}
1145                            StopReason::MaxTokens => {}
1146                        },
1147                        Err(error) => {
1148                            if error.is::<PaymentRequiredError>() {
1149                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1150                            } else if error.is::<MaxMonthlySpendReachedError>() {
1151                                cx.emit(ThreadEvent::ShowError(
1152                                    ThreadError::MaxMonthlySpendReached,
1153                                ));
1154                            } else if let Some(error) =
1155                                error.downcast_ref::<ModelRequestLimitReachedError>()
1156                            {
1157                                cx.emit(ThreadEvent::ShowError(
1158                                    ThreadError::ModelRequestLimitReached { plan: error.plan },
1159                                ));
1160                            } else if let Some(known_error) =
1161                                error.downcast_ref::<LanguageModelKnownError>()
1162                            {
1163                                match known_error {
1164                                    LanguageModelKnownError::ContextWindowLimitExceeded {
1165                                        tokens,
1166                                    } => {
1167                                        thread.exceeded_window_error = Some(ExceededWindowError {
1168                                            model_id: model.id(),
1169                                            token_count: *tokens,
1170                                        });
1171                                        cx.notify();
1172                                    }
1173                                }
1174                            } else {
1175                                let error_message = error
1176                                    .chain()
1177                                    .map(|err| err.to_string())
1178                                    .collect::<Vec<_>>()
1179                                    .join("\n");
1180                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1181                                    header: "Error interacting with language model".into(),
1182                                    message: SharedString::from(error_message.clone()),
1183                                }));
1184                            }
1185
1186                            thread.cancel_last_completion(cx);
1187                        }
1188                    }
1189                    cx.emit(ThreadEvent::Stopped(result.map_err(Arc::new)));
1190
1191                    thread.auto_capture_telemetry(cx);
1192
1193                    if let Ok(initial_usage) = initial_token_usage {
1194                        let usage = thread.cumulative_token_usage - initial_usage;
1195
1196                        telemetry::event!(
1197                            "Assistant Thread Completion",
1198                            thread_id = thread.id().to_string(),
1199                            model = model.telemetry_id(),
1200                            model_provider = model.provider_id().to_string(),
1201                            input_tokens = usage.input_tokens,
1202                            output_tokens = usage.output_tokens,
1203                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1204                            cache_read_input_tokens = usage.cache_read_input_tokens,
1205                        );
1206                    }
1207                })
1208                .ok();
1209        });
1210
1211        self.pending_completions.push(PendingCompletion {
1212            id: pending_completion_id,
1213            _task: task,
1214        });
1215    }
1216
1217    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1218        let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else {
1219            return;
1220        };
1221
1222        if !model.provider.is_authenticated(cx) {
1223            return;
1224        }
1225
1226        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1227        request.messages.push(LanguageModelRequestMessage {
1228            role: Role::User,
1229            content: vec![
1230                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1231                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1232                 If the conversation is about a specific subject, include it in the title. \
1233                 Be descriptive. DO NOT speak in the first person."
1234                    .into(),
1235            ],
1236            cache: false,
1237        });
1238
1239        self.pending_summary = cx.spawn(async move |this, cx| {
1240            async move {
1241                let stream = model.model.stream_completion_text(request, &cx);
1242                let mut messages = stream.await?;
1243
1244                let mut new_summary = String::new();
1245                while let Some(message) = messages.stream.next().await {
1246                    let text = message?;
1247                    let mut lines = text.lines();
1248                    new_summary.extend(lines.next());
1249
1250                    // Stop if the LLM generated multiple lines.
1251                    if lines.next().is_some() {
1252                        break;
1253                    }
1254                }
1255
1256                this.update(cx, |this, cx| {
1257                    if !new_summary.is_empty() {
1258                        this.summary = Some(new_summary.into());
1259                    }
1260
1261                    cx.emit(ThreadEvent::SummaryGenerated);
1262                })?;
1263
1264                anyhow::Ok(())
1265            }
1266            .log_err()
1267            .await
1268        });
1269    }
1270
1271    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1272        let last_message_id = self.messages.last().map(|message| message.id)?;
1273
1274        match &self.detailed_summary_state {
1275            DetailedSummaryState::Generating { message_id, .. }
1276            | DetailedSummaryState::Generated { message_id, .. }
1277                if *message_id == last_message_id =>
1278            {
1279                // Already up-to-date
1280                return None;
1281            }
1282            _ => {}
1283        }
1284
1285        let ConfiguredModel { model, provider } =
1286            LanguageModelRegistry::read_global(cx).thread_summary_model()?;
1287
1288        if !provider.is_authenticated(cx) {
1289            return None;
1290        }
1291
1292        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1293
1294        request.messages.push(LanguageModelRequestMessage {
1295            role: Role::User,
1296            content: vec![
1297                "Generate a detailed summary of this conversation. Include:\n\
1298                1. A brief overview of what was discussed\n\
1299                2. Key facts or information discovered\n\
1300                3. Outcomes or conclusions reached\n\
1301                4. Any action items or next steps if any\n\
1302                Format it in Markdown with headings and bullet points."
1303                    .into(),
1304            ],
1305            cache: false,
1306        });
1307
1308        let task = cx.spawn(async move |thread, cx| {
1309            let stream = model.stream_completion_text(request, &cx);
1310            let Some(mut messages) = stream.await.log_err() else {
1311                thread
1312                    .update(cx, |this, _cx| {
1313                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1314                    })
1315                    .log_err();
1316
1317                return;
1318            };
1319
1320            let mut new_detailed_summary = String::new();
1321
1322            while let Some(chunk) = messages.stream.next().await {
1323                if let Some(chunk) = chunk.log_err() {
1324                    new_detailed_summary.push_str(&chunk);
1325                }
1326            }
1327
1328            thread
1329                .update(cx, |this, _cx| {
1330                    this.detailed_summary_state = DetailedSummaryState::Generated {
1331                        text: new_detailed_summary.into(),
1332                        message_id: last_message_id,
1333                    };
1334                })
1335                .log_err();
1336        });
1337
1338        self.detailed_summary_state = DetailedSummaryState::Generating {
1339            message_id: last_message_id,
1340        };
1341
1342        Some(task)
1343    }
1344
1345    pub fn is_generating_detailed_summary(&self) -> bool {
1346        matches!(
1347            self.detailed_summary_state,
1348            DetailedSummaryState::Generating { .. }
1349        )
1350    }
1351
1352    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) -> Vec<PendingToolUse> {
1353        self.auto_capture_telemetry(cx);
1354        let request = self.to_completion_request(RequestKind::Chat, cx);
1355        let messages = Arc::new(request.messages);
1356        let pending_tool_uses = self
1357            .tool_use
1358            .pending_tool_uses()
1359            .into_iter()
1360            .filter(|tool_use| tool_use.status.is_idle())
1361            .cloned()
1362            .collect::<Vec<_>>();
1363
1364        for tool_use in pending_tool_uses.iter() {
1365            if let Some(tool) = self.tools.read(cx).tool(&tool_use.name, cx) {
1366                if tool.needs_confirmation(&tool_use.input, cx)
1367                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1368                {
1369                    self.tool_use.confirm_tool_use(
1370                        tool_use.id.clone(),
1371                        tool_use.ui_text.clone(),
1372                        tool_use.input.clone(),
1373                        messages.clone(),
1374                        tool,
1375                    );
1376                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1377                } else {
1378                    self.run_tool(
1379                        tool_use.id.clone(),
1380                        tool_use.ui_text.clone(),
1381                        tool_use.input.clone(),
1382                        &messages,
1383                        tool,
1384                        cx,
1385                    );
1386                }
1387            }
1388        }
1389
1390        pending_tool_uses
1391    }
1392
1393    pub fn run_tool(
1394        &mut self,
1395        tool_use_id: LanguageModelToolUseId,
1396        ui_text: impl Into<SharedString>,
1397        input: serde_json::Value,
1398        messages: &[LanguageModelRequestMessage],
1399        tool: Arc<dyn Tool>,
1400        cx: &mut Context<Thread>,
1401    ) {
1402        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1403        self.tool_use
1404            .run_pending_tool(tool_use_id, ui_text.into(), task);
1405    }
1406
1407    fn spawn_tool_use(
1408        &mut self,
1409        tool_use_id: LanguageModelToolUseId,
1410        messages: &[LanguageModelRequestMessage],
1411        input: serde_json::Value,
1412        tool: Arc<dyn Tool>,
1413        cx: &mut Context<Thread>,
1414    ) -> Task<()> {
1415        let tool_name: Arc<str> = tool.name().into();
1416
1417        let tool_result = if self.tools.read(cx).is_disabled(&tool.source(), &tool_name) {
1418            Task::ready(Err(anyhow!("tool is disabled: {tool_name}"))).into()
1419        } else {
1420            tool.run(
1421                input,
1422                messages,
1423                self.project.clone(),
1424                self.action_log.clone(),
1425                cx,
1426            )
1427        };
1428
1429        cx.spawn({
1430            async move |thread: WeakEntity<Thread>, cx| {
1431                let output = tool_result.output.await;
1432
1433                thread
1434                    .update(cx, |thread, cx| {
1435                        let pending_tool_use = thread.tool_use.insert_tool_output(
1436                            tool_use_id.clone(),
1437                            tool_name,
1438                            output,
1439                            cx,
1440                        );
1441                        thread.tool_finished(tool_use_id, pending_tool_use, false, cx);
1442                    })
1443                    .ok();
1444            }
1445        })
1446    }
1447
1448    fn tool_finished(
1449        &mut self,
1450        tool_use_id: LanguageModelToolUseId,
1451        pending_tool_use: Option<PendingToolUse>,
1452        canceled: bool,
1453        cx: &mut Context<Self>,
1454    ) {
1455        if self.all_tools_finished() {
1456            let model_registry = LanguageModelRegistry::read_global(cx);
1457            if let Some(ConfiguredModel { model, .. }) = model_registry.default_model() {
1458                self.attach_tool_results(cx);
1459                if !canceled {
1460                    self.send_to_model(model, RequestKind::Chat, cx);
1461                }
1462            }
1463        }
1464
1465        cx.emit(ThreadEvent::ToolFinished {
1466            tool_use_id,
1467            pending_tool_use,
1468        });
1469    }
1470
1471    pub fn attach_tool_results(&mut self, cx: &mut Context<Self>) {
1472        // Insert a user message to contain the tool results.
1473        self.insert_user_message(
1474            // TODO: Sending up a user message without any content results in the model sending back
1475            // responses that also don't have any content. We currently don't handle this case well,
1476            // so for now we provide some text to keep the model on track.
1477            "Here are the tool results.",
1478            Vec::new(),
1479            None,
1480            cx,
1481        );
1482    }
1483
1484    /// Cancels the last pending completion, if there are any pending.
1485    ///
1486    /// Returns whether a completion was canceled.
1487    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1488        let canceled = if self.pending_completions.pop().is_some() {
1489            true
1490        } else {
1491            let mut canceled = false;
1492            for pending_tool_use in self.tool_use.cancel_pending() {
1493                canceled = true;
1494                self.tool_finished(
1495                    pending_tool_use.id.clone(),
1496                    Some(pending_tool_use),
1497                    true,
1498                    cx,
1499                );
1500            }
1501            canceled
1502        };
1503        self.finalize_pending_checkpoint(cx);
1504        canceled
1505    }
1506
1507    pub fn feedback(&self) -> Option<ThreadFeedback> {
1508        self.feedback
1509    }
1510
1511    pub fn message_feedback(&self, message_id: MessageId) -> Option<ThreadFeedback> {
1512        self.message_feedback.get(&message_id).copied()
1513    }
1514
1515    pub fn report_message_feedback(
1516        &mut self,
1517        message_id: MessageId,
1518        feedback: ThreadFeedback,
1519        cx: &mut Context<Self>,
1520    ) -> Task<Result<()>> {
1521        if self.message_feedback.get(&message_id) == Some(&feedback) {
1522            return Task::ready(Ok(()));
1523        }
1524
1525        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1526        let serialized_thread = self.serialize(cx);
1527        let thread_id = self.id().clone();
1528        let client = self.project.read(cx).client();
1529
1530        let enabled_tool_names: Vec<String> = self
1531            .tools()
1532            .read(cx)
1533            .enabled_tools(cx)
1534            .iter()
1535            .map(|tool| tool.name().to_string())
1536            .collect();
1537
1538        self.message_feedback.insert(message_id, feedback);
1539
1540        cx.notify();
1541
1542        let message_content = self
1543            .message(message_id)
1544            .map(|msg| msg.to_string())
1545            .unwrap_or_default();
1546
1547        cx.background_spawn(async move {
1548            let final_project_snapshot = final_project_snapshot.await;
1549            let serialized_thread = serialized_thread.await?;
1550            let thread_data =
1551                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1552
1553            let rating = match feedback {
1554                ThreadFeedback::Positive => "positive",
1555                ThreadFeedback::Negative => "negative",
1556            };
1557            telemetry::event!(
1558                "Assistant Thread Rated",
1559                rating,
1560                thread_id,
1561                enabled_tool_names,
1562                message_id = message_id.0,
1563                message_content,
1564                thread_data,
1565                final_project_snapshot
1566            );
1567            client.telemetry().flush_events();
1568
1569            Ok(())
1570        })
1571    }
1572
1573    pub fn report_feedback(
1574        &mut self,
1575        feedback: ThreadFeedback,
1576        cx: &mut Context<Self>,
1577    ) -> Task<Result<()>> {
1578        let last_assistant_message_id = self
1579            .messages
1580            .iter()
1581            .rev()
1582            .find(|msg| msg.role == Role::Assistant)
1583            .map(|msg| msg.id);
1584
1585        if let Some(message_id) = last_assistant_message_id {
1586            self.report_message_feedback(message_id, feedback, cx)
1587        } else {
1588            let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1589            let serialized_thread = self.serialize(cx);
1590            let thread_id = self.id().clone();
1591            let client = self.project.read(cx).client();
1592            self.feedback = Some(feedback);
1593            cx.notify();
1594
1595            cx.background_spawn(async move {
1596                let final_project_snapshot = final_project_snapshot.await;
1597                let serialized_thread = serialized_thread.await?;
1598                let thread_data = serde_json::to_value(serialized_thread)
1599                    .unwrap_or_else(|_| serde_json::Value::Null);
1600
1601                let rating = match feedback {
1602                    ThreadFeedback::Positive => "positive",
1603                    ThreadFeedback::Negative => "negative",
1604                };
1605                telemetry::event!(
1606                    "Assistant Thread Rated",
1607                    rating,
1608                    thread_id,
1609                    thread_data,
1610                    final_project_snapshot
1611                );
1612                client.telemetry().flush_events();
1613
1614                Ok(())
1615            })
1616        }
1617    }
1618
1619    /// Create a snapshot of the current project state including git information and unsaved buffers.
1620    fn project_snapshot(
1621        project: Entity<Project>,
1622        cx: &mut Context<Self>,
1623    ) -> Task<Arc<ProjectSnapshot>> {
1624        let git_store = project.read(cx).git_store().clone();
1625        let worktree_snapshots: Vec<_> = project
1626            .read(cx)
1627            .visible_worktrees(cx)
1628            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1629            .collect();
1630
1631        cx.spawn(async move |_, cx| {
1632            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1633
1634            let mut unsaved_buffers = Vec::new();
1635            cx.update(|app_cx| {
1636                let buffer_store = project.read(app_cx).buffer_store();
1637                for buffer_handle in buffer_store.read(app_cx).buffers() {
1638                    let buffer = buffer_handle.read(app_cx);
1639                    if buffer.is_dirty() {
1640                        if let Some(file) = buffer.file() {
1641                            let path = file.path().to_string_lossy().to_string();
1642                            unsaved_buffers.push(path);
1643                        }
1644                    }
1645                }
1646            })
1647            .ok();
1648
1649            Arc::new(ProjectSnapshot {
1650                worktree_snapshots,
1651                unsaved_buffer_paths: unsaved_buffers,
1652                timestamp: Utc::now(),
1653            })
1654        })
1655    }
1656
1657    fn worktree_snapshot(
1658        worktree: Entity<project::Worktree>,
1659        git_store: Entity<GitStore>,
1660        cx: &App,
1661    ) -> Task<WorktreeSnapshot> {
1662        cx.spawn(async move |cx| {
1663            // Get worktree path and snapshot
1664            let worktree_info = cx.update(|app_cx| {
1665                let worktree = worktree.read(app_cx);
1666                let path = worktree.abs_path().to_string_lossy().to_string();
1667                let snapshot = worktree.snapshot();
1668                (path, snapshot)
1669            });
1670
1671            let Ok((worktree_path, _snapshot)) = worktree_info else {
1672                return WorktreeSnapshot {
1673                    worktree_path: String::new(),
1674                    git_state: None,
1675                };
1676            };
1677
1678            let git_state = git_store
1679                .update(cx, |git_store, cx| {
1680                    git_store
1681                        .repositories()
1682                        .values()
1683                        .find(|repo| {
1684                            repo.read(cx)
1685                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1686                                .is_some()
1687                        })
1688                        .cloned()
1689                })
1690                .ok()
1691                .flatten()
1692                .map(|repo| {
1693                    repo.update(cx, |repo, _| {
1694                        let current_branch =
1695                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1696                        repo.send_job(None, |state, _| async move {
1697                            let RepositoryState::Local { backend, .. } = state else {
1698                                return GitState {
1699                                    remote_url: None,
1700                                    head_sha: None,
1701                                    current_branch,
1702                                    diff: None,
1703                                };
1704                            };
1705
1706                            let remote_url = backend.remote_url("origin");
1707                            let head_sha = backend.head_sha();
1708                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1709
1710                            GitState {
1711                                remote_url,
1712                                head_sha,
1713                                current_branch,
1714                                diff,
1715                            }
1716                        })
1717                    })
1718                });
1719
1720            let git_state = match git_state {
1721                Some(git_state) => match git_state.ok() {
1722                    Some(git_state) => git_state.await.ok(),
1723                    None => None,
1724                },
1725                None => None,
1726            };
1727
1728            WorktreeSnapshot {
1729                worktree_path,
1730                git_state,
1731            }
1732        })
1733    }
1734
1735    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1736        let mut markdown = Vec::new();
1737
1738        if let Some(summary) = self.summary() {
1739            writeln!(markdown, "# {summary}\n")?;
1740        };
1741
1742        for message in self.messages() {
1743            writeln!(
1744                markdown,
1745                "## {role}\n",
1746                role = match message.role {
1747                    Role::User => "User",
1748                    Role::Assistant => "Assistant",
1749                    Role::System => "System",
1750                }
1751            )?;
1752
1753            if !message.context.is_empty() {
1754                writeln!(markdown, "{}", message.context)?;
1755            }
1756
1757            for segment in &message.segments {
1758                match segment {
1759                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1760                    MessageSegment::Thinking(text) => {
1761                        writeln!(markdown, "<think>{}</think>\n", text)?
1762                    }
1763                }
1764            }
1765
1766            for tool_use in self.tool_uses_for_message(message.id, cx) {
1767                writeln!(
1768                    markdown,
1769                    "**Use Tool: {} ({})**",
1770                    tool_use.name, tool_use.id
1771                )?;
1772                writeln!(markdown, "```json")?;
1773                writeln!(
1774                    markdown,
1775                    "{}",
1776                    serde_json::to_string_pretty(&tool_use.input)?
1777                )?;
1778                writeln!(markdown, "```")?;
1779            }
1780
1781            for tool_result in self.tool_results_for_message(message.id) {
1782                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1783                if tool_result.is_error {
1784                    write!(markdown, " (Error)")?;
1785                }
1786
1787                writeln!(markdown, "**\n")?;
1788                writeln!(markdown, "{}", tool_result.content)?;
1789            }
1790        }
1791
1792        Ok(String::from_utf8_lossy(&markdown).to_string())
1793    }
1794
1795    pub fn keep_edits_in_range(
1796        &mut self,
1797        buffer: Entity<language::Buffer>,
1798        buffer_range: Range<language::Anchor>,
1799        cx: &mut Context<Self>,
1800    ) {
1801        self.action_log.update(cx, |action_log, cx| {
1802            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1803        });
1804    }
1805
1806    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1807        self.action_log
1808            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1809    }
1810
1811    pub fn reject_edits_in_ranges(
1812        &mut self,
1813        buffer: Entity<language::Buffer>,
1814        buffer_ranges: Vec<Range<language::Anchor>>,
1815        cx: &mut Context<Self>,
1816    ) -> Task<Result<()>> {
1817        self.action_log.update(cx, |action_log, cx| {
1818            action_log.reject_edits_in_ranges(buffer, buffer_ranges, cx)
1819        })
1820    }
1821
    /// Returns the [`ActionLog`] entity owned by this thread.
    ///
    /// `keep_edits_in_range`, `keep_all_edits` and `reject_edits_in_ranges` all delegate to it.
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1825
    /// Returns the [`Project`] entity held by this thread.
    ///
    /// Project snapshots and the telemetry client are obtained from it.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1829
1830    pub fn auto_capture_telemetry(&mut self, cx: &mut Context<Self>) {
1831        if !cx.has_flag::<feature_flags::ThreadAutoCapture>() {
1832            return;
1833        }
1834
1835        let now = Instant::now();
1836        if let Some(last) = self.last_auto_capture_at {
1837            if now.duration_since(last).as_secs() < 10 {
1838                return;
1839            }
1840        }
1841
1842        self.last_auto_capture_at = Some(now);
1843
1844        let thread_id = self.id().clone();
1845        let github_login = self
1846            .project
1847            .read(cx)
1848            .user_store()
1849            .read(cx)
1850            .current_user()
1851            .map(|user| user.github_login.clone());
1852        let client = self.project.read(cx).client().clone();
1853        let serialize_task = self.serialize(cx);
1854
1855        cx.background_executor()
1856            .spawn(async move {
1857                if let Ok(serialized_thread) = serialize_task.await {
1858                    if let Ok(thread_data) = serde_json::to_value(serialized_thread) {
1859                        telemetry::event!(
1860                            "Agent Thread Auto-Captured",
1861                            thread_id = thread_id.to_string(),
1862                            thread_data = thread_data,
1863                            auto_capture_reason = "tracked_user",
1864                            github_login = github_login
1865                        );
1866
1867                        client.telemetry().flush_events();
1868                    }
1869                }
1870            })
1871            .detach();
1872    }
1873
    /// Returns a copy of the token usage accumulated by this thread.
    ///
    /// `total_token_usage` compares this value's `total_tokens()` against the default model's
    /// maximum token count.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage
    }
1877
1878    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1879        let model_registry = LanguageModelRegistry::read_global(cx);
1880        let Some(model) = model_registry.default_model() else {
1881            return TotalTokenUsage::default();
1882        };
1883
1884        let max = model.model.max_token_count();
1885
1886        if let Some(exceeded_error) = &self.exceeded_window_error {
1887            if model.model.id() == exceeded_error.model_id {
1888                return TotalTokenUsage {
1889                    total: exceeded_error.token_count,
1890                    max,
1891                    ratio: TokenUsageRatio::Exceeded,
1892                };
1893            }
1894        }
1895
1896        #[cfg(debug_assertions)]
1897        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1898            .unwrap_or("0.8".to_string())
1899            .parse()
1900            .unwrap();
1901        #[cfg(not(debug_assertions))]
1902        let warning_threshold: f32 = 0.8;
1903
1904        let total = self.cumulative_token_usage.total_tokens() as usize;
1905
1906        let ratio = if total >= max {
1907            TokenUsageRatio::Exceeded
1908        } else if total as f32 / max as f32 >= warning_threshold {
1909            TokenUsageRatio::Warning
1910        } else {
1911            TokenUsageRatio::Normal
1912        };
1913
1914        TotalTokenUsage { total, max, ratio }
1915    }
1916
1917    pub fn deny_tool_use(
1918        &mut self,
1919        tool_use_id: LanguageModelToolUseId,
1920        tool_name: Arc<str>,
1921        cx: &mut Context<Self>,
1922    ) {
1923        let err = Err(anyhow::anyhow!(
1924            "Permission to run tool action denied by user"
1925        ));
1926
1927        self.tool_use
1928            .insert_tool_output(tool_use_id.clone(), tool_name, err, cx);
1929        self.tool_finished(tool_use_id.clone(), None, true, cx);
1930    }
1931}
1932
/// Errors that can be delivered to listeners through `ThreadEvent::ShowError`.
///
/// Each variant's `#[error(...)]` attribute (from `thiserror`) defines its `Display` text.
#[derive(Debug, Clone, Error)]
pub enum ThreadError {
    /// Displays as "Payment required".
    // NOTE(review): presumably built from a provider `PaymentRequiredError`; confirm at the
    // construction site.
    #[error("Payment required")]
    PaymentRequired,
    /// Displays as "Max monthly spend reached".
    // NOTE(review): presumably built from a `MaxMonthlySpendReachedError`; confirm.
    #[error("Max monthly spend reached")]
    MaxMonthlySpendReached,
    /// Displays as "Model request limit reached". The `plan` is not part of the message.
    // NOTE(review): `plan` looks like the user's subscription plan when the limit was hit;
    // confirm.
    #[error("Model request limit reached")]
    ModelRequestLimitReached { plan: Plan },
    /// A free-form error, displayed as "Message {header}: {message}".
    #[error("Message {header}: {message}")]
    Message {
        header: SharedString,
        message: SharedString,
    },
}
1947
/// Events emitted by a [`Thread`].
///
/// The emit sites are not shown here, so the per-variant descriptions below are inferred from
/// variant names and payloads; verify them against the code that emits each event.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error intended to be shown to the user.
    ShowError(ThreadError),
    /// NOTE(review): presumably signals progress of a streaming completion; confirm.
    StreamedCompletion,
    /// Assistant text streamed for the message with the given id (presumably an incremental
    /// chunk; confirm).
    StreamedAssistantText(MessageId, String),
    /// Assistant "thinking" text streamed for the message with the given id (presumably an
    /// incremental chunk; confirm).
    StreamedAssistantThinking(MessageId, String),
    /// The completion stopped. `Ok` carries the stop reason and `Err` the failure.
    ///
    /// The error is wrapped in an `Arc` because `anyhow::Error` is not `Clone` and this enum
    /// derives `Clone`.
    Stopped(Result<StopReason, Arc<anyhow::Error>>),
    /// Carries the id of a message that was added to the thread.
    MessageAdded(MessageId),
    /// Carries the id of a message that was edited.
    MessageEdited(MessageId),
    /// Carries the id of a message that was deleted.
    MessageDeleted(MessageId),
    /// NOTE(review): presumably emitted once a thread summary has been generated; confirm.
    SummaryGenerated,
    /// NOTE(review): presumably emitted whenever the thread summary changes; confirm.
    SummaryChanged,
    /// Tool uses that are pending (presumably for the listener to run; confirm).
    UsePendingTools {
        tool_uses: Vec<PendingToolUse>,
    },
    /// A tool use finished.
    ToolFinished {
        // NOTE(review): `#[allow(unused)]` suggests nothing in this crate reads this field;
        // verify before relying on it.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
    },
    /// NOTE(review): presumably emitted when a message checkpoint changes; confirm.
    CheckpointChanged,
    /// NOTE(review): presumably signals that a tool is waiting for user confirmation; confirm.
    ToolConfirmationNeeded,
}
1972
// Declares `ThreadEvent` as the event type `Thread` entities emit. The empty body shows that
// gpui's `EventEmitter` requires no items to be implemented here.
impl EventEmitter<ThreadEvent> for Thread {}
1974
/// Bookkeeping for a completion request (presumably one still in flight; confirm).
struct PendingCompletion {
    /// Identifier of this completion.
    // NOTE(review): presumably allocated from a per-thread counter (e.g. `post_inc`); confirm.
    id: usize,
    /// Task driving the completion. The leading underscore marks it as never read: it is held
    /// only so the task stays owned.
    // NOTE(review): gpui tasks are cancelled on drop, so dropping this entry presumably
    // cancels the completion; confirm.
    _task: Task<()>,
}
1979
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ThreadStore, context_store::ContextStore, thread_store};
    use assistant_settings::AssistantSettings;
    use context_server::ContextServerSettings;
    use editor::EditorSettings;
    use gpui::TestAppContext;
    use project::{FakeFs, Project};
    use prompt_store::PromptBuilder;
    use serde_json::json;
    use settings::{Settings, SettingsStore};
    use std::sync::Arc;
    use theme::ThemeSettings;
    use util::path;
    use workspace::Workspace;

    /// A user message with an attached file stores the formatted context on the message.
    /// In the completion request, that context is prepended to the message text.
    #[gpui::test]
    async fn test_message_with_context(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with context
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Please explain this code", vec![context], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Use different path format strings based on platform for the test
        #[cfg(windows)]
        let path_part = r"test\code.rs";
        #[cfg(not(windows))]
        let path_part = "test/code.rs";

        let expected_context = format!(
            r#"
<context>
The following items were attached by the user. You don't need to use other tools to read them.

<files>
```rs {path_part}
fn main() {{
    println!("Hello, world!");
}}
```
</files>
</context>
"#
        );

        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("Please explain this code".to_string())
        );
        assert_eq!(message.context, expected_context);

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        let expected_full_message = format!("{}Please explain this code", expected_context);
        assert_eq!(request.messages[1].string_contents(), expected_full_message);
    }

    /// Context already attached to an earlier message is not repeated in later messages.
    #[gpui::test]
    async fn test_only_include_new_contexts(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({
                "file1.rs": "fn function1() {}\n",
                "file2.rs": "fn function2() {}\n",
                "file3.rs": "fn function3() {}\n",
            }),
        )
        .await;

        let (_, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open files individually
        add_file_to_context(&project, &context_store, "test/file1.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file2.rs", cx)
            .await
            .unwrap();
        add_file_to_context(&project, &context_store, "test/file3.rs", cx)
            .await
            .unwrap();

        // Get the context objects
        let contexts = context_store.update(cx, |store, _| store.context().clone());
        assert_eq!(contexts.len(), 3);

        // First message with context 1
        let message1_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Message 1", vec![contexts[0].clone()], None, cx)
        });

        // Second message with contexts 1 and 2 (context 1 should be skipped as it's already included)
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 2",
                vec![contexts[0].clone(), contexts[1].clone()],
                None,
                cx,
            )
        });

        // Third message with all three contexts (contexts 1 and 2 should be skipped)
        let message3_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message(
                "Message 3",
                vec![
                    contexts[0].clone(),
                    contexts[1].clone(),
                    contexts[2].clone(),
                ],
                None,
                cx,
            )
        });

        // Check what contexts are included in each message
        let (message1, message2, message3) = thread.read_with(cx, |thread, _| {
            (
                thread.message(message1_id).unwrap().clone(),
                thread.message(message2_id).unwrap().clone(),
                thread.message(message3_id).unwrap().clone(),
            )
        });

        // First message should include context 1
        assert!(message1.context.contains("file1.rs"));

        // Second message should include only context 2 (not 1)
        assert!(!message2.context.contains("file1.rs"));
        assert!(message2.context.contains("file2.rs"));

        // Third message should include only context 3 (not 1 or 2)
        assert!(!message3.context.contains("file1.rs"));
        assert!(!message3.context.contains("file2.rs"));
        assert!(message3.context.contains("file3.rs"));

        // Check entire request to make sure all contexts are properly included
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // The request should contain the leading message at index 0 plus all 3 user messages
        assert_eq!(request.messages.len(), 4);

        // Check that the contexts are properly formatted in each message
        assert!(request.messages[1].string_contents().contains("file1.rs"));
        assert!(!request.messages[1].string_contents().contains("file2.rs"));
        assert!(!request.messages[1].string_contents().contains("file3.rs"));

        assert!(!request.messages[2].string_contents().contains("file1.rs"));
        assert!(request.messages[2].string_contents().contains("file2.rs"));
        assert!(!request.messages[2].string_contents().contains("file3.rs"));

        assert!(!request.messages[3].string_contents().contains("file1.rs"));
        assert!(!request.messages[3].string_contents().contains("file2.rs"));
        assert!(request.messages[3].string_contents().contains("file3.rs"));
    }

    /// Messages inserted without context have an empty context and plain request content.
    #[gpui::test]
    async fn test_message_without_files(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_, _thread_store, thread, _context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Insert user message without any context (empty context vector)
        let message_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("What is the best way to learn Rust?", vec![], None, cx)
        });

        // Check content and context in message object
        let message = thread.read_with(cx, |thread, _| thread.message(message_id).unwrap().clone());

        // Context should be empty when no files are included
        assert_eq!(message.role, Role::User);
        assert_eq!(message.segments.len(), 1);
        assert_eq!(
            message.segments[0],
            MessageSegment::Text("What is the best way to learn Rust?".to_string())
        );
        assert_eq!(message.context, "");

        // Check message in request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 2);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );

        // Add second message, also without context
        let message2_id = thread.update(cx, |thread, cx| {
            thread.insert_user_message("Are there any good books?", vec![], None, cx)
        });

        let message2 =
            thread.read_with(cx, |thread, _| thread.message(message2_id).unwrap().clone());
        assert_eq!(message2.context, "");

        // Check that both messages appear in the request
        let request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        assert_eq!(request.messages.len(), 3);
        assert_eq!(
            request.messages[1].string_contents(),
            "What is the best way to learn Rust?"
        );
        assert_eq!(
            request.messages[2].string_contents(),
            "Are there any good books?"
        );
    }

    /// Editing a buffer after it was attached as context appends a stale-file notification as
    /// the last message of the next request.
    #[gpui::test]
    async fn test_stale_buffer_notification(cx: &mut TestAppContext) {
        init_test_settings(cx);

        let project = create_test_project(
            cx,
            json!({"code.rs": "fn main() {\n    println!(\"Hello, world!\");\n}"}),
        )
        .await;

        let (_workspace, _thread_store, thread, context_store) =
            setup_test_environment(cx, project.clone()).await;

        // Open buffer and add it to context
        let buffer = add_file_to_context(&project, &context_store, "test/code.rs", cx)
            .await
            .unwrap();

        let context =
            context_store.update(cx, |store, _| store.context().first().cloned().unwrap());

        // Insert user message with the buffer as context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("Explain this code", vec![context], None, cx)
        });

        // Create a request and check that it doesn't have a stale buffer warning yet
        let initial_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // Make sure we don't have a stale file warning yet
        let has_stale_warning = initial_request.messages.iter().any(|msg| {
            msg.string_contents()
                .contains("These files changed since last read:")
        });
        assert!(
            !has_stale_warning,
            "Should not have stale buffer warning before buffer is modified"
        );

        // Modify the buffer
        buffer.update(cx, |buffer, cx| {
            // Insert text at offset 1 (inside the first line); any edit is enough to make the
            // buffer differ from what was last read.
            buffer.edit(
                [(1..1, "\n    println!(\"Added a new line\");\n")],
                None,
                cx,
            );
        });

        // Insert another user message without context
        thread.update(cx, |thread, cx| {
            thread.insert_user_message("What does the code do now?", vec![], None, cx)
        });

        // Create a new request and check for the stale buffer warning
        let new_request = thread.read_with(cx, |thread, cx| {
            thread.to_completion_request(RequestKind::Chat, cx)
        });

        // We should have a stale file warning as the last message
        let last_message = new_request
            .messages
            .last()
            .expect("Request should have messages");

        // The last message should be the stale buffer notification
        assert_eq!(last_message.role, Role::User);

        // Check the exact content of the message
        let expected_content = "These files changed since last read:\n- code.rs\n";
        assert_eq!(
            last_message.string_contents(),
            expected_content,
            "Last message should be exactly the stale buffer notification"
        );
    }

    /// Install a test settings store and register every settings type these tests rely on.
    fn init_test_settings(cx: &mut TestAppContext) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
            AssistantSettings::register(cx);
            thread_store::init(cx);
            workspace::init_settings(cx);
            ThemeSettings::register(cx);
            ContextServerSettings::register(cx);
            EditorSettings::register(cx);
        });
    }

    /// Create a test project rooted at `/test` on a fake filesystem populated with `files`.
    async fn create_test_project(
        cx: &mut TestAppContext,
        files: serde_json::Value,
    ) -> Entity<Project> {
        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(path!("/test"), files).await;
        Project::test(fs, [path!("/test").as_ref()], cx).await
    }

    /// Open a test workspace for `project` and create a thread store, a fresh thread, and an
    /// empty context store.
    async fn setup_test_environment(
        cx: &mut TestAppContext,
        project: Entity<Project>,
    ) -> (
        Entity<Workspace>,
        Entity<ThreadStore>,
        Entity<Thread>,
        Entity<ContextStore>,
    ) {
        let (workspace, cx) =
            cx.add_window_view(|window, cx| Workspace::test_new(project.clone(), window, cx));

        let thread_store = cx
            .update(|_, cx| {
                ThreadStore::load(
                    project.clone(),
                    cx.new(|_| ToolWorkingSet::default()),
                    Arc::new(PromptBuilder::new(None).unwrap()),
                    cx,
                )
            })
            .await;

        let thread = thread_store.update(cx, |store, cx| store.create_thread(cx));
        let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None));

        (workspace, thread_store, thread, context_store)
    }

    /// Open the buffer at `path` in `project` and add it to `context_store`.
    ///
    /// Returns an error, instead of panicking, when the path is not in the project, the buffer
    /// cannot be opened, or adding it to the context store fails.
    async fn add_file_to_context(
        project: &Entity<Project>,
        context_store: &Entity<ContextStore>,
        path: &str,
        cx: &mut TestAppContext,
    ) -> Result<Entity<language::Buffer>> {
        let buffer_path = project
            .read_with(cx, |project, cx| project.find_project_path(path, cx))
            .ok_or_else(|| anyhow::anyhow!("path {path:?} not found in project"))?;

        let buffer = project
            .update(cx, |project, cx| project.open_buffer(buffer_path, cx))
            .await?;

        context_store
            .update(cx, |store, cx| {
                store.add_file_from_buffer(buffer.clone(), cx)
            })
            .await?;

        Ok(buffer)
    }
}