thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5
   6use anyhow::{Context as _, Result, anyhow};
   7use assistant_settings::AssistantSettings;
   8use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
   9use chrono::{DateTime, Utc};
  10use collections::{BTreeMap, HashMap, HashSet};
  11use fs::Fs;
  12use futures::future::Shared;
  13use futures::{FutureExt, StreamExt as _};
  14use git::repository::DiffType;
  15use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  16use language_model::{
  17    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
  18    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  19    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
  20    Role, StopReason, TokenUsage,
  21};
  22use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  23use project::{Project, Worktree};
  24use prompt_store::{
  25    AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt,
  26};
  27use schemars::JsonSchema;
  28use serde::{Deserialize, Serialize};
  29use settings::Settings;
  30use util::{ResultExt as _, TryFutureExt as _, maybe, post_inc};
  31use uuid::Uuid;
  32
  33use crate::context::{AssistantContext, ContextId, attach_context_to_message};
  34use crate::thread_store::{
  35    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  36    SerializedToolUse,
  37};
  38use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState, USING_TOOL_MARKER};
  39
  40#[derive(Debug, Clone, Copy)]
  41pub enum RequestKind {
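        /// Used for regular chat requests.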
  42    Chat,
  43    /// Used when summarizing a thread.
  44    Summarize,
  45}
  46
  47#[derive(
  48    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  49)]
  50pub struct ThreadId(Arc<str>);
  51
  52impl ThreadId {
  53    pub fn new() -> Self {
  54        Self(Uuid::new_v4().to_string().into())
  55    }
  56}
  57
  58impl std::fmt::Display for ThreadId {
  59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  60        write!(f, "{}", self.0)
  61    }
  62}
  63
  64impl From<&str> for ThreadId {
  65    fn from(value: &str) -> Self {
  66        Self(value.into())
  67    }
  68}
  69
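    /// Identifies a [`Message`] within a thread; IDs increase monotonically as messages are inserted.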
  70#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  71pub struct MessageId(pub(crate) usize);
  72
  73impl MessageId {
  74    fn post_inc(&mut self) -> Self {
  75        Self(post_inc(&mut self.0))
  76    }
  77}
  78
  79/// A message in a [`Thread`].
  80#[derive(Debug, Clone)]
  81pub struct Message {
  82    pub id: MessageId,
  83    pub role: Role,
  84    pub segments: Vec<MessageSegment>,
  85}
  86
  87impl Message {
  88    /// Returns whether the message contains meaningful text that should be displayed.
  89    /// The model sometimes runs tools without producing any text, or produces only a marker ([`USING_TOOL_MARKER`]).
  90    pub fn should_display_content(&self) -> bool {
  91        self.segments.iter().all(|segment| segment.should_display())
  92    }
  93
  94    pub fn push_thinking(&mut self, text: &str) {
  95        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  96            segment.push_str(text);
  97        } else {
  98            self.segments
  99                .push(MessageSegment::Thinking(text.to_string()));
 100        }
 101    }
 102
 103    pub fn push_text(&mut self, text: &str) {
 104        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
 105            segment.push_str(text);
 106        } else {
 107            self.segments.push(MessageSegment::Text(text.to_string()));
 108        }
 109    }
 110
 111    pub fn to_string(&self) -> String {
 112        let mut result = String::new();
 113        for segment in &self.segments {
 114            match segment {
 115                MessageSegment::Text(text) => result.push_str(text),
 116                MessageSegment::Thinking(text) => {
 117                    result.push_str("<think>");
 118                    result.push_str(text);
 119                    result.push_str("</think>");
 120                }
 121            }
 122        }
 123        result
 124    }
 125}
 126
 127#[derive(Debug, Clone)]
 128pub enum MessageSegment {
 129    Text(String),
 130    Thinking(String),
 131}
 132
 133impl MessageSegment {
 134    pub fn text_mut(&mut self) -> &mut String {
 135        match self {
 136            Self::Text(text) => text,
 137            Self::Thinking(text) => text,
 138        }
 139    }
 140
 141    pub fn should_display(&self) -> bool {
 142        // We add USING_TOOL_MARKER when making a request that includes tool uses
 143        // without non-whitespace text around them, and this can cause the model
 144        // to mimic the pattern, so we consider those segments not displayable.
 145        match self {
 146            Self::Text(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
 147            Self::Thinking(text) => !text.is_empty() && text.trim() != USING_TOOL_MARKER,
 148        }
 149    }
 150}
 151
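    /// A snapshot of the project's worktrees, unsaved buffers, and git state, captured when the thread is created.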
 152#[derive(Debug, Clone, Serialize, Deserialize)]
 153pub struct ProjectSnapshot {
 154    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 155    pub unsaved_buffer_paths: Vec<String>,
 156    pub timestamp: DateTime<Utc>,
 157}
 158
 159#[derive(Debug, Clone, Serialize, Deserialize)]
 160pub struct WorktreeSnapshot {
 161    pub worktree_path: String,
 162    pub git_state: Option<GitState>,
 163}
 164
 165#[derive(Debug, Clone, Serialize, Deserialize)]
 166pub struct GitState {
 167    pub remote_url: Option<String>,
 168    pub head_sha: Option<String>,
 169    pub current_branch: Option<String>,
 170    pub diff: Option<String>,
 171}
 172
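    /// Pairs a message with the git checkpoint taken when it was sent, so the project can be restored to that point.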
 173#[derive(Clone)]
 174pub struct ThreadCheckpoint {
 175    message_id: MessageId,
 176    git_checkpoint: GitStoreCheckpoint,
 177}
 178
 179#[derive(Copy, Clone, Debug)]
 180pub enum ThreadFeedback {
 181    Positive,
 182    Negative,
 183}
 184
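    /// The most recent checkpoint restore, which is either still pending or has failed with an error.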
 185pub enum LastRestoreCheckpoint {
 186    Pending {
 187        message_id: MessageId,
 188    },
 189    Error {
 190        message_id: MessageId,
 191        error: String,
 192    },
 193}
 194
 195impl LastRestoreCheckpoint {
 196    pub fn message_id(&self) -> MessageId {
 197        match self {
 198            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 199            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 200        }
 201    }
 202}
 203
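    /// Progress of generating a detailed summary of the thread.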
 204#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 205pub enum DetailedSummaryState {
 206    #[default]
 207    NotGenerated,
 208    Generating {
 209        message_id: MessageId,
 210    },
 211    Generated {
 212        text: SharedString,
 213        message_id: MessageId,
 214    },
 215}
 216
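    /// Total token usage for the thread, the maximum allowed, and how close usage is to that limit.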
 217#[derive(Default)]
 218pub struct TotalTokenUsage {
 219    pub total: usize,
 220    pub max: usize,
 221    pub ratio: TokenUsageRatio,
 222}
 223
 224#[derive(Default, PartialEq, Eq)]
 225pub enum TokenUsageRatio {
 226    #[default]
 227    Normal,
 228    Warning,
 229    Exceeded,
 230}
 231
 232/// A thread of conversation with the LLM.
 233pub struct Thread {
 234    id: ThreadId,
 235    updated_at: DateTime<Utc>,
 236    summary: Option<SharedString>,
 237    pending_summary: Task<Option<()>>,
 238    detailed_summary_state: DetailedSummaryState,
 239    messages: Vec<Message>,
 240    next_message_id: MessageId,
 241    context: BTreeMap<ContextId, AssistantContext>,
 242    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 243    system_prompt_context: Option<AssistantSystemPromptContext>,
 244    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
 245    completion_count: usize,
 246    pending_completions: Vec<PendingCompletion>,
 247    project: Entity<Project>,
 248    prompt_builder: Arc<PromptBuilder>,
 249    tools: Arc<ToolWorkingSet>,
 250    tool_use: ToolUseState,
 251    action_log: Entity<ActionLog>,
 252    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
 253    pending_checkpoint: Option<ThreadCheckpoint>,
 254    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 255    cumulative_token_usage: TokenUsage,
 256    feedback: Option<ThreadFeedback>,
 257}
 258
 259impl Thread {
 260    pub fn new(
 261        project: Entity<Project>,
 262        tools: Arc<ToolWorkingSet>,
 263        prompt_builder: Arc<PromptBuilder>,
 264        cx: &mut Context<Self>,
 265    ) -> Self {
 266        Self {
 267            id: ThreadId::new(),
 268            updated_at: Utc::now(),
 269            summary: None,
 270            pending_summary: Task::ready(None),
 271            detailed_summary_state: DetailedSummaryState::NotGenerated,
 272            messages: Vec::new(),
 273            next_message_id: MessageId(0),
 274            context: BTreeMap::default(),
 275            context_by_message: HashMap::default(),
 276            system_prompt_context: None,
 277            checkpoints_by_message: HashMap::default(),
 278            completion_count: 0,
 279            pending_completions: Vec::new(),
 280            project: project.clone(),
 281            prompt_builder,
 282            tools: tools.clone(),
 283            last_restore_checkpoint: None,
 284            pending_checkpoint: None,
 285            tool_use: ToolUseState::new(tools.clone()),
 286            action_log: cx.new(|_| ActionLog::new()),
 287            initial_project_snapshot: {
 288                let project_snapshot = Self::project_snapshot(project, cx);
 289                cx.foreground_executor()
 290                    .spawn(async move { Some(project_snapshot.await) })
 291                    .shared()
 292            },
 293            cumulative_token_usage: TokenUsage::default(),
 294            feedback: None,
 295        }
 296    }
 297
 298    pub fn deserialize(
 299        id: ThreadId,
 300        serialized: SerializedThread,
 301        project: Entity<Project>,
 302        tools: Arc<ToolWorkingSet>,
 303        prompt_builder: Arc<PromptBuilder>,
 304        cx: &mut Context<Self>,
 305    ) -> Self {
 306        let next_message_id = MessageId(
 307            serialized
 308                .messages
 309                .last()
 310                .map(|message| message.id.0 + 1)
 311                .unwrap_or(0),
 312        );
 313        let tool_use =
 314            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 315
 316        Self {
 317            id,
 318            updated_at: serialized.updated_at,
 319            summary: Some(serialized.summary),
 320            pending_summary: Task::ready(None),
 321            detailed_summary_state: serialized.detailed_summary_state,
 322            messages: serialized
 323                .messages
 324                .into_iter()
 325                .map(|message| Message {
 326                    id: message.id,
 327                    role: message.role,
 328                    segments: message
 329                        .segments
 330                        .into_iter()
 331                        .map(|segment| match segment {
 332                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 333                            SerializedMessageSegment::Thinking { text } => {
 334                                MessageSegment::Thinking(text)
 335                            }
 336                        })
 337                        .collect(),
 338                })
 339                .collect(),
 340            next_message_id,
 341            context: BTreeMap::default(),
 342            context_by_message: HashMap::default(),
 343            system_prompt_context: None,
 344            checkpoints_by_message: HashMap::default(),
 345            completion_count: 0,
 346            pending_completions: Vec::new(),
 347            last_restore_checkpoint: None,
 348            pending_checkpoint: None,
 349            project,
 350            prompt_builder,
 351            tools,
 352            tool_use,
 353            action_log: cx.new(|_| ActionLog::new()),
 354            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 355            cumulative_token_usage: serialized.cumulative_token_usage,
 356            feedback: None,
 357        }
 358    }
 359
 360    pub fn id(&self) -> &ThreadId {
 361        &self.id
 362    }
 363
 364    pub fn is_empty(&self) -> bool {
 365        self.messages.is_empty()
 366    }
 367
 368    pub fn updated_at(&self) -> DateTime<Utc> {
 369        self.updated_at
 370    }
 371
 372    pub fn touch_updated_at(&mut self) {
 373        self.updated_at = Utc::now();
 374    }
 375
 376    pub fn summary(&self) -> Option<SharedString> {
 377        self.summary.clone()
 378    }
 379
 380    pub fn summary_or_default(&self) -> SharedString {
 381        const DEFAULT: SharedString = SharedString::new_static("New Thread");
 382        self.summary.clone().unwrap_or(DEFAULT)
 383    }
 384
 385    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 386        self.summary = Some(summary.into());
 387        cx.emit(ThreadEvent::SummaryChanged);
 388    }
 389
 390    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 391        self.latest_detailed_summary()
 392            .unwrap_or_else(|| self.text().into())
 393    }
 394
 395    fn latest_detailed_summary(&self) -> Option<SharedString> {
 396        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 397            Some(text.clone())
 398        } else {
 399            None
 400        }
 401    }
 402
 403    pub fn message(&self, id: MessageId) -> Option<&Message> {
 404        self.messages.iter().find(|message| message.id == id)
 405    }
 406
 407    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 408        self.messages.iter()
 409    }
 410
 411    pub fn is_generating(&self) -> bool {
 412        !self.pending_completions.is_empty() || !self.all_tools_finished()
 413    }
 414
 415    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 416        &self.tools
 417    }
 418
 419    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 420        self.tool_use
 421            .pending_tool_uses()
 422            .into_iter()
 423            .find(|tool_use| &tool_use.id == id)
 424    }
 425
 426    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 427        self.tool_use
 428            .pending_tool_uses()
 429            .into_iter()
 430            .filter(|tool_use| tool_use.status.needs_confirmation())
 431    }
 432
 433    pub fn has_pending_tool_uses(&self) -> bool {
 434        !self.tool_use.pending_tool_uses().is_empty()
 435    }
 436
 437    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 438        self.checkpoints_by_message.get(&id).cloned()
 439    }
 440
 441    pub fn restore_checkpoint(
 442        &mut self,
 443        checkpoint: ThreadCheckpoint,
 444        cx: &mut Context<Self>,
 445    ) -> Task<Result<()>> {
 446        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 447            message_id: checkpoint.message_id,
 448        });
 449        cx.emit(ThreadEvent::CheckpointChanged);
 450        cx.notify();
 451
 452        let project = self.project.read(cx);
 453        let restore = project
 454            .git_store()
 455            .read(cx)
 456            .restore_checkpoint(checkpoint.git_checkpoint.clone(), cx);
 457        cx.spawn(async move |this, cx| {
 458            let result = restore.await;
 459            this.update(cx, |this, cx| {
 460                if let Err(err) = result.as_ref() {
 461                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 462                        message_id: checkpoint.message_id,
 463                        error: err.to_string(),
 464                    });
 465                } else {
 466                    this.truncate(checkpoint.message_id, cx);
 467                    this.last_restore_checkpoint = None;
 468                }
 469                this.pending_checkpoint = None;
 470                cx.emit(ThreadEvent::CheckpointChanged);
 471                cx.notify();
 472            })?;
 473            result
 474        })
 475    }
 476
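        /// Once generation has finished, either promotes the pending checkpoint to a permanent one
        /// or discards it if the project is unchanged since the checkpoint was taken.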
 477    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 478        let pending_checkpoint = if self.is_generating() {
 479            return;
 480        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 481            checkpoint
 482        } else {
 483            return;
 484        };
 485
 486        let git_store = self.project.read(cx).git_store().clone();
 487        let final_checkpoint = git_store.read(cx).checkpoint(cx);
 488        cx.spawn(async move |this, cx| match final_checkpoint.await {
 489            Ok(final_checkpoint) => {
 490                let equal = git_store
 491                    .read_with(cx, |store, cx| {
 492                        store.compare_checkpoints(
 493                            pending_checkpoint.git_checkpoint.clone(),
 494                            final_checkpoint.clone(),
 495                            cx,
 496                        )
 497                    })?
 498                    .await
 499                    .unwrap_or(false);
 500
 501                if equal {
 502                    git_store
 503                        .read_with(cx, |store, cx| {
 504                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 505                        })?
 506                        .detach();
 507                } else {
 508                    this.update(cx, |this, cx| {
 509                        this.insert_checkpoint(pending_checkpoint, cx)
 510                    })?;
 511                }
 512
 513                git_store
 514                    .read_with(cx, |store, cx| {
 515                        store.delete_checkpoint(final_checkpoint, cx)
 516                    })?
 517                    .detach();
 518
 519                Ok(())
 520            }
 521            Err(_) => this.update(cx, |this, cx| {
 522                this.insert_checkpoint(pending_checkpoint, cx)
 523            }),
 524        })
 525        .detach();
 526    }
 527
 528    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 529        self.checkpoints_by_message
 530            .insert(checkpoint.message_id, checkpoint);
 531        cx.emit(ThreadEvent::CheckpointChanged);
 532        cx.notify();
 533    }
 534
 535    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 536        self.last_restore_checkpoint.as_ref()
 537    }
 538
 539    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 540        let Some(message_ix) = self
 541            .messages
 542            .iter()
 543            .rposition(|message| message.id == message_id)
 544        else {
 545            return;
 546        };
 547        for deleted_message in self.messages.drain(message_ix..) {
 548            self.context_by_message.remove(&deleted_message.id);
 549            self.checkpoints_by_message.remove(&deleted_message.id);
 550        }
 551        cx.notify();
 552    }
 553
 554    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 555        self.context_by_message
 556            .get(&id)
 557            .into_iter()
 558            .flat_map(|context| {
 559                context
 560                    .iter()
 561                    .filter_map(|context_id| self.context.get(context_id))
 562            })
 563    }
 564
 565    /// Returns whether all of the tool uses have finished running.
 566    pub fn all_tools_finished(&self) -> bool {
 567        // If the only pending tool uses left are the ones with errors, then
 568        // that means that we've finished running all of the pending tools.
 569        self.tool_use
 570            .pending_tool_uses()
 571            .iter()
 572            .all(|tool_use| tool_use.status.is_error())
 573    }
 574
 575    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 576        self.tool_use.tool_uses_for_message(id, cx)
 577    }
 578
 579    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 580        self.tool_use.tool_results_for_message(id)
 581    }
 582
 583    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 584        self.tool_use.tool_result(id)
 585    }
 586
 587    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 588        self.tool_use.message_has_tool_results(message_id)
 589    }
 590
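        /// Inserts a user message with the given text, attaching the provided context and, if given, a git checkpoint.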
 591    pub fn insert_user_message(
 592        &mut self,
 593        text: impl Into<String>,
 594        context: Vec<AssistantContext>,
 595        git_checkpoint: Option<GitStoreCheckpoint>,
 596        cx: &mut Context<Self>,
 597    ) -> MessageId {
 598        let message_id =
 599            self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
 600        let context_ids = context
 601            .iter()
 602            .map(|context| context.id())
 603            .collect::<Vec<_>>();
 604        self.context
 605            .extend(context.into_iter().map(|context| (context.id(), context)));
 606        self.context_by_message.insert(message_id, context_ids);
 607        if let Some(git_checkpoint) = git_checkpoint {
 608            self.pending_checkpoint = Some(ThreadCheckpoint {
 609                message_id,
 610                git_checkpoint,
 611            });
 612        }
 613        message_id
 614    }
 615
 616    pub fn insert_message(
 617        &mut self,
 618        role: Role,
 619        segments: Vec<MessageSegment>,
 620        cx: &mut Context<Self>,
 621    ) -> MessageId {
 622        let id = self.next_message_id.post_inc();
 623        self.messages.push(Message { id, role, segments });
 624        self.touch_updated_at();
 625        cx.emit(ThreadEvent::MessageAdded(id));
 626        id
 627    }
 628
 629    pub fn edit_message(
 630        &mut self,
 631        id: MessageId,
 632        new_role: Role,
 633        new_segments: Vec<MessageSegment>,
 634        cx: &mut Context<Self>,
 635    ) -> bool {
 636        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 637            return false;
 638        };
 639        message.role = new_role;
 640        message.segments = new_segments;
 641        self.touch_updated_at();
 642        cx.emit(ThreadEvent::MessageEdited(id));
 643        true
 644    }
 645
 646    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 647        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 648            return false;
 649        };
 650        self.messages.remove(index);
 651        self.context_by_message.remove(&id);
 652        self.touch_updated_at();
 653        cx.emit(ThreadEvent::MessageDeleted(id));
 654        true
 655    }
 656
 657    /// Returns the representation of this [`Thread`] in a textual form.
 658    ///
 659    /// This is the representation we use when attaching a thread as context to another thread.
 660    pub fn text(&self) -> String {
 661        let mut text = String::new();
 662
 663        for message in &self.messages {
 664            text.push_str(match message.role {
 665                language_model::Role::User => "User:",
 666                language_model::Role::Assistant => "Assistant:",
 667                language_model::Role::System => "System:",
 668            });
 669            text.push('\n');
 670
 671            for segment in &message.segments {
 672                match segment {
 673                    MessageSegment::Text(content) => text.push_str(content),
 674                    MessageSegment::Thinking(content) => {
 675                        text.push_str(&format!("<think>{}</think>", content))
 676                    }
 677                }
 678            }
 679            text.push('\n');
 680        }
 681
 682        text
 683    }
 684
 685    /// Serializes this thread into a format for storage or telemetry.
 686    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 687        let initial_project_snapshot = self.initial_project_snapshot.clone();
 688        cx.spawn(async move |this, cx| {
 689            let initial_project_snapshot = initial_project_snapshot.await;
 690            this.read_with(cx, |this, cx| SerializedThread {
 691                version: SerializedThread::VERSION.to_string(),
 692                summary: this.summary_or_default(),
 693                updated_at: this.updated_at(),
 694                messages: this
 695                    .messages()
 696                    .map(|message| SerializedMessage {
 697                        id: message.id,
 698                        role: message.role,
 699                        segments: message
 700                            .segments
 701                            .iter()
 702                            .map(|segment| match segment {
 703                                MessageSegment::Text(text) => {
 704                                    SerializedMessageSegment::Text { text: text.clone() }
 705                                }
 706                                MessageSegment::Thinking(text) => {
 707                                    SerializedMessageSegment::Thinking { text: text.clone() }
 708                                }
 709                            })
 710                            .collect(),
 711                        tool_uses: this
 712                            .tool_uses_for_message(message.id, cx)
 713                            .into_iter()
 714                            .map(|tool_use| SerializedToolUse {
 715                                id: tool_use.id,
 716                                name: tool_use.name,
 717                                input: tool_use.input,
 718                            })
 719                            .collect(),
 720                        tool_results: this
 721                            .tool_results_for_message(message.id)
 722                            .into_iter()
 723                            .map(|tool_result| SerializedToolResult {
 724                                tool_use_id: tool_result.tool_use_id.clone(),
 725                                is_error: tool_result.is_error,
 726                                content: tool_result.content.clone(),
 727                            })
 728                            .collect(),
 729                    })
 730                    .collect(),
 731                initial_project_snapshot,
 732                cumulative_token_usage: this.cumulative_token_usage.clone(),
 733                detailed_summary_state: this.detailed_summary_state.clone(),
 734            })
 735        })
 736    }
 737
 738    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
 739        self.system_prompt_context = Some(context);
 740    }
 741
 742    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
 743        &self.system_prompt_context
 744    }
 745
 746    pub fn load_system_prompt_context(
 747        &self,
 748        cx: &App,
 749    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
 750        let project = self.project.read(cx);
 751        let tasks = project
 752            .visible_worktrees(cx)
 753            .map(|worktree| {
 754                Self::load_worktree_info_for_system_prompt(
 755                    project.fs().clone(),
 756                    worktree.read(cx),
 757                    cx,
 758                )
 759            })
 760            .collect::<Vec<_>>();
 761
 762        cx.spawn(async |_cx| {
 763            let results = futures::future::join_all(tasks).await;
 764            let mut first_err = None;
 765            let worktrees = results
 766                .into_iter()
 767                .map(|(worktree, err)| {
 768                    if first_err.is_none() && err.is_some() {
 769                        first_err = err;
 770                    }
 771                    worktree
 772                })
 773                .collect::<Vec<_>>();
 774            (AssistantSystemPromptContext::new(worktrees), first_err)
 775        })
 776    }
 777
 778    fn load_worktree_info_for_system_prompt(
 779        fs: Arc<dyn Fs>,
 780        worktree: &Worktree,
 781        cx: &App,
 782    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
 783        let root_name = worktree.root_name().into();
 784        let abs_path = worktree.abs_path();
 785
 786        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 787        // supported. This doesn't seem to occur often in GitHub repositories.
 788        const RULES_FILE_NAMES: [&'static str; 6] = [
 789            ".rules",
 790            ".cursorrules",
 791            ".windsurfrules",
 792            ".clinerules",
 793            ".github/copilot-instructions.md",
 794            "CLAUDE.md",
 795        ];
 796        let selected_rules_file = RULES_FILE_NAMES
 797            .into_iter()
 798            .filter_map(|name| {
 799                worktree
 800                    .entry_for_path(name)
 801                    .filter(|entry| entry.is_file())
 802                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
 803            })
 804            .next();
 805
 806        if let Some((rel_rules_path, abs_rules_path)) = selected_rules_file {
 807            cx.spawn(async move |_| {
 808                let rules_file_result = maybe!(async move {
 809                    let abs_rules_path = abs_rules_path?;
 810                    let text = fs.load(&abs_rules_path).await.with_context(|| {
 811                        format!("Failed to load assistant rules file {:?}", abs_rules_path)
 812                    })?;
 813                    anyhow::Ok(RulesFile {
 814                        rel_path: rel_rules_path,
 815                        abs_path: abs_rules_path.into(),
 816                        text: text.trim().to_string(),
 817                    })
 818                })
 819                .await;
 820                let (rules_file, rules_file_error) = match rules_file_result {
 821                    Ok(rules_file) => (Some(rules_file), None),
 822                    Err(err) => (
 823                        None,
 824                        Some(ThreadError::Message {
 825                            header: "Error loading rules file".into(),
 826                            message: format!("{err}").into(),
 827                        }),
 828                    ),
 829                };
 830                let worktree_info = WorktreeInfoForSystemPrompt {
 831                    root_name,
 832                    abs_path,
 833                    rules_file,
 834                };
 835                (worktree_info, rules_file_error)
 836            })
 837        } else {
 838            Task::ready((
 839                WorktreeInfoForSystemPrompt {
 840                    root_name,
 841                    abs_path,
 842                    rules_file: None,
 843                },
 844                None,
 845            ))
 846        }
 847    }
 848
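        /// Sends the conversation to the given model and starts streaming the response.
        ///
        /// A rough usage sketch (assumes an existing `thread: Entity<Thread>`, a configured
        /// `model: Arc<dyn LanguageModel>`, and a gpui context):
        ///
        /// ```ignore
        /// thread.update(cx, |thread, cx| {
        ///     thread.insert_user_message("Summarize the latest changes", Vec::new(), None, cx);
        ///     thread.send_to_model(model.clone(), RequestKind::Chat, cx);
        /// });
        /// ```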
 849    pub fn send_to_model(
 850        &mut self,
 851        model: Arc<dyn LanguageModel>,
 852        request_kind: RequestKind,
 853        cx: &mut Context<Self>,
 854    ) {
 855        let mut request = self.to_completion_request(request_kind, cx);
 856        if model.supports_tools() {
 857            request.tools = {
 858                let mut tools = Vec::new();
 859                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 860                    LanguageModelRequestTool {
 861                        name: tool.name(),
 862                        description: tool.description(),
 863                        input_schema: tool.input_schema(model.tool_input_format()),
 864                    }
 865                }));
 866
 867                tools
 868            };
 869        }
 870
 871        self.stream_completion(request, model, cx);
 872    }
 873
 874    pub fn used_tools_since_last_user_message(&self) -> bool {
 875        for message in self.messages.iter().rev() {
 876            if self.tool_use.message_has_tool_results(message.id) {
 877                return true;
 878            } else if message.role == Role::User {
 879                return false;
 880            }
 881        }
 882
 883        false
 884    }
 885
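        /// Builds a [`LanguageModelRequest`] from this thread: the system prompt, each message with
        /// its newly referenced context, tool uses/results (for chat requests), and a cache breakpoint.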
 886    pub fn to_completion_request(
 887        &self,
 888        request_kind: RequestKind,
 889        cx: &App,
 890    ) -> LanguageModelRequest {
 891        let mut request = LanguageModelRequest {
 892            messages: vec![],
 893            tools: Vec::new(),
 894            stop: Vec::new(),
 895            temperature: None,
 896        };
 897
 898        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
 899            if let Some(system_prompt) = self
 900                .prompt_builder
 901                .generate_assistant_system_prompt(system_prompt_context)
 902                .context("failed to generate assistant system prompt")
 903                .log_err()
 904            {
 905                request.messages.push(LanguageModelRequestMessage {
 906                    role: Role::System,
 907                    content: vec![MessageContent::Text(system_prompt)],
 908                    cache: true,
 909                });
 910            }
 911        } else {
 912            log::error!("system_prompt_context not set.")
 913        }
 914
 915        let mut added_context_ids = HashSet::<ContextId>::default();
 916
 917        for message in &self.messages {
 918            let mut request_message = LanguageModelRequestMessage {
 919                role: message.role,
 920                content: Vec::new(),
 921                cache: false,
 922            };
 923
 924            match request_kind {
 925                RequestKind::Chat => {
 926                    self.tool_use
 927                        .attach_tool_results(message.id, &mut request_message);
 928                }
 929                RequestKind::Summarize => {
 930                    // We don't care about tool use during summarization.
 931                    if self.tool_use.message_has_tool_results(message.id) {
 932                        continue;
 933                    }
 934                }
 935            }
 936
 937            // Attach context to this message if it's the first to reference it
 938            if let Some(context_ids) = self.context_by_message.get(&message.id) {
 939                let new_context_ids: Vec<_> = context_ids
 940                    .iter()
 941                    .filter(|id| !added_context_ids.contains(id))
 942                    .collect();
 943
 944                if !new_context_ids.is_empty() {
 945                    let referenced_context = new_context_ids
 946                        .iter()
 947                        .filter_map(|context_id| self.context.get(*context_id));
 948
 949                    attach_context_to_message(&mut request_message, referenced_context, cx);
 950                    added_context_ids.extend(context_ids.iter());
 951                }
 952            }
 953
 954            if !message.segments.is_empty() {
 955                request_message
 956                    .content
 957                    .push(MessageContent::Text(message.to_string()));
 958            }
 959
 960            match request_kind {
 961                RequestKind::Chat => {
 962                    self.tool_use
 963                        .attach_tool_uses(message.id, &mut request_message);
 964                }
 965                RequestKind::Summarize => {
 966                    // We don't care about tool use during summarization.
 967                }
 968            };
 969
 970            request.messages.push(request_message);
 971        }
 972
 973        // Set a cache breakpoint at the second-to-last message.
 974        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 975        let breakpoint_index = request.messages.len().saturating_sub(2);
 976        for (index, message) in request.messages.iter_mut().enumerate() {
 977            message.cache = index == breakpoint_index;
 978        }
 979
 980        self.attached_tracked_files_state(&mut request.messages, cx);
 981
 982        request
 983    }
 984
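        /// Appends a trailing user message listing buffers that changed since the model last read
        /// them and, if files were edited since the last diagnostics check, a reminder to fix any
        /// introduced errors and warnings.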
 985    fn attached_tracked_files_state(
 986        &self,
 987        messages: &mut Vec<LanguageModelRequestMessage>,
 988        cx: &App,
 989    ) {
 990        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 991
 992        let mut stale_message = String::new();
 993
 994        let action_log = self.action_log.read(cx);
 995
 996        for stale_file in action_log.stale_buffers(cx) {
 997            let Some(file) = stale_file.read(cx).file() else {
 998                continue;
 999            };
1000
1001            if stale_message.is_empty() {
1002                writeln!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
1003            }
1004
1005            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
1006        }
1007
1008        let mut content = Vec::with_capacity(2);
1009
1010        if !stale_message.is_empty() {
1011            content.push(stale_message.into());
1012        }
1013
1014        if action_log.has_edited_files_since_project_diagnostics_check() {
1015            content.push(
1016                "\n\nWhen you're done making changes, make sure to check project diagnostics \
1017                and fix all errors AND warnings you introduced! \
1018                DO NOT mention you're going to do this until you're done."
1019                    .into(),
1020            );
1021        }
1022
1023        if !content.is_empty() {
1024            let context_message = LanguageModelRequestMessage {
1025                role: Role::User,
1026                content,
1027                cache: false,
1028            };
1029
1030            messages.push(context_message);
1031        }
1032    }
1033
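        /// Streams a completion for the given request, appending streamed text, thinking, and tool
        /// uses to the thread as events arrive, and reporting token usage when the stream ends.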
1034    pub fn stream_completion(
1035        &mut self,
1036        request: LanguageModelRequest,
1037        model: Arc<dyn LanguageModel>,
1038        cx: &mut Context<Self>,
1039    ) {
1040        let pending_completion_id = post_inc(&mut self.completion_count);
1041
1042        let task = cx.spawn(async move |thread, cx| {
1043            let stream = model.stream_completion(request, &cx);
1044            let initial_token_usage =
1045                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
1046            let stream_completion = async {
1047                let mut events = stream.await?;
1048                let mut stop_reason = StopReason::EndTurn;
1049                let mut current_token_usage = TokenUsage::default();
1050
1051                while let Some(event) = events.next().await {
1052                    let event = event?;
1053
1054                    thread.update(cx, |thread, cx| {
1055                        match event {
1056                            LanguageModelCompletionEvent::StartMessage { .. } => {
1057                                thread.insert_message(
1058                                    Role::Assistant,
1059                                    vec![MessageSegment::Text(String::new())],
1060                                    cx,
1061                                );
1062                            }
1063                            LanguageModelCompletionEvent::Stop(reason) => {
1064                                stop_reason = reason;
1065                            }
1066                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
1067                                thread.cumulative_token_usage =
1068                                    thread.cumulative_token_usage.clone() + token_usage.clone()
1069                                        - current_token_usage.clone();
1070                                current_token_usage = token_usage;
1071                            }
1072                            LanguageModelCompletionEvent::Text(chunk) => {
1073                                if let Some(last_message) = thread.messages.last_mut() {
1074                                    if last_message.role == Role::Assistant {
1075                                        last_message.push_text(&chunk);
1076                                        cx.emit(ThreadEvent::StreamedAssistantText(
1077                                            last_message.id,
1078                                            chunk,
1079                                        ));
1080                                    } else {
1081                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1082                                        // of a new Assistant response.
1083                                        //
1084                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
1085                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1086                                        thread.insert_message(
1087                                            Role::Assistant,
1088                                            vec![MessageSegment::Text(chunk.to_string())],
1089                                            cx,
1090                                        );
1091                                    };
1092                                }
1093                            }
1094                            LanguageModelCompletionEvent::Thinking(chunk) => {
1095                                if let Some(last_message) = thread.messages.last_mut() {
1096                                    if last_message.role == Role::Assistant {
1097                                        last_message.push_thinking(&chunk);
1098                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
1099                                            last_message.id,
1100                                            chunk,
1101                                        ));
1102                                    } else {
1103                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
1104                                        // of a new Assistant response.
1105                                        //
1106                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
1107                                        // will result in duplicating the text of the chunk in the rendered Markdown.
1108                                        thread.insert_message(
1109                                            Role::Assistant,
1110                                            vec![MessageSegment::Thinking(chunk.to_string())],
1111                                            cx,
1112                                        );
1113                                    };
1114                                }
1115                            }
1116                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
1117                                let last_assistant_message_id = thread
1118                                    .messages
1119                                    .iter_mut()
1120                                    .rfind(|message| message.role == Role::Assistant)
1121                                    .map(|message| message.id)
1122                                    .unwrap_or_else(|| {
1123                                        thread.insert_message(Role::Assistant, vec![], cx)
1124                                    });
1125
1126                                thread.tool_use.request_tool_use(
1127                                    last_assistant_message_id,
1128                                    tool_use,
1129                                    cx,
1130                                );
1131                            }
1132                        }
1133
1134                        thread.touch_updated_at();
1135                        cx.emit(ThreadEvent::StreamedCompletion);
1136                        cx.notify();
1137                    })?;
1138
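                        // Yield back to the executor after each event so other tasks can make progress.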
1139                    smol::future::yield_now().await;
1140                }
1141
1142                thread.update(cx, |thread, cx| {
1143                    thread
1144                        .pending_completions
1145                        .retain(|completion| completion.id != pending_completion_id);
1146
1147                    if thread.summary.is_none() && thread.messages.len() >= 2 {
1148                        thread.summarize(cx);
1149                    }
1150                })?;
1151
1152                anyhow::Ok(stop_reason)
1153            };
1154
1155            let result = stream_completion.await;
1156
1157            thread
1158                .update(cx, |thread, cx| {
1159                    thread.finalize_pending_checkpoint(cx);
1160                    match result.as_ref() {
1161                        Ok(stop_reason) => match stop_reason {
1162                            StopReason::ToolUse => {
1163                                cx.emit(ThreadEvent::UsePendingTools);
1164                            }
1165                            StopReason::EndTurn => {}
1166                            StopReason::MaxTokens => {}
1167                        },
1168                        Err(error) => {
1169                            if error.is::<PaymentRequiredError>() {
1170                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
1171                            } else if error.is::<MaxMonthlySpendReachedError>() {
1172                                cx.emit(ThreadEvent::ShowError(
1173                                    ThreadError::MaxMonthlySpendReached,
1174                                ));
1175                            } else {
1176                                let error_message = error
1177                                    .chain()
1178                                    .map(|err| err.to_string())
1179                                    .collect::<Vec<_>>()
1180                                    .join("\n");
1181                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
1182                                    header: "Error interacting with language model".into(),
1183                                    message: SharedString::from(error_message.clone()),
1184                                }));
1185                            }
1186
1187                            thread.cancel_last_completion(cx);
1188                        }
1189                    }
1190                    cx.emit(ThreadEvent::DoneStreaming);
1191
1192                    if let Ok(initial_usage) = initial_token_usage {
1193                        let usage = thread.cumulative_token_usage.clone() - initial_usage;
1194
1195                        telemetry::event!(
1196                            "Assistant Thread Completion",
1197                            thread_id = thread.id().to_string(),
1198                            model = model.telemetry_id(),
1199                            model_provider = model.provider_id().to_string(),
1200                            input_tokens = usage.input_tokens,
1201                            output_tokens = usage.output_tokens,
1202                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
1203                            cache_read_input_tokens = usage.cache_read_input_tokens,
1204                        );
1205                    }
1206                })
1207                .ok();
1208        });
1209
1210        self.pending_completions.push(PendingCompletion {
1211            id: pending_completion_id,
1212            _task: task,
1213        });
1214    }
1215
1216    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1217        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
1218            return;
1219        };
1220        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
1221            return;
1222        };
1223
1224        if !provider.is_authenticated(cx) {
1225            return;
1226        }
1227
1228        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1229        request.messages.push(LanguageModelRequestMessage {
1230            role: Role::User,
1231            content: vec![
1232                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1233                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1234                 If the conversation is about a specific subject, include it in the title. \
1235                 Be descriptive. DO NOT speak in the first person."
1236                    .into(),
1237            ],
1238            cache: false,
1239        });
1240
1241        self.pending_summary = cx.spawn(async move |this, cx| {
1242            async move {
1243                let stream = model.stream_completion_text(request, &cx);
1244                let mut messages = stream.await?;
1245
1246                let mut new_summary = String::new();
1247                while let Some(message) = messages.stream.next().await {
1248                    let text = message?;
1249                    let mut lines = text.lines();
1250                    new_summary.extend(lines.next());
1251
1252                    // Stop if the LLM generated multiple lines.
1253                    if lines.next().is_some() {
1254                        break;
1255                    }
1256                }
1257
1258                this.update(cx, |this, cx| {
1259                    if !new_summary.is_empty() {
1260                        this.summary = Some(new_summary.into());
1261                    }
1262
1263                    cx.emit(ThreadEvent::SummaryChanged);
1264                })?;
1265
1266                anyhow::Ok(())
1267            }
1268            .log_err()
1269            .await
1270        });
1271    }
1272
1273    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1274        let last_message_id = self.messages.last().map(|message| message.id)?;
1275
1276        match &self.detailed_summary_state {
1277            DetailedSummaryState::Generating { message_id, .. }
1278            | DetailedSummaryState::Generated { message_id, .. }
1279                if *message_id == last_message_id =>
1280            {
1281                // Already up-to-date
1282                return None;
1283            }
1284            _ => {}
1285        }
1286
1287        let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
1288        let model = LanguageModelRegistry::read_global(cx).active_model()?;
1289
1290        if !provider.is_authenticated(cx) {
1291            return None;
1292        }
1293
1294        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1295
1296        request.messages.push(LanguageModelRequestMessage {
1297            role: Role::User,
1298            content: vec![
1299                "Generate a detailed summary of this conversation. Include:\n\
1300                1. A brief overview of what was discussed\n\
1301                2. Key facts or information discovered\n\
1302                3. Outcomes or conclusions reached\n\
1303                4. Any action items or next steps if any\n\
1304                Format it in Markdown with headings and bullet points."
1305                    .into(),
1306            ],
1307            cache: false,
1308        });
1309
1310        let task = cx.spawn(async move |thread, cx| {
1311            let stream = model.stream_completion_text(request, &cx);
1312            let Some(mut messages) = stream.await.log_err() else {
1313                thread
1314                    .update(cx, |this, _cx| {
1315                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1316                    })
1317                    .log_err();
1318
1319                return;
1320            };
1321
1322            let mut new_detailed_summary = String::new();
1323
1324            while let Some(chunk) = messages.stream.next().await {
1325                if let Some(chunk) = chunk.log_err() {
1326                    new_detailed_summary.push_str(&chunk);
1327                }
1328            }
1329
1330            thread
1331                .update(cx, |this, _cx| {
1332                    this.detailed_summary_state = DetailedSummaryState::Generated {
1333                        text: new_detailed_summary.into(),
1334                        message_id: last_message_id,
1335                    };
1336                })
1337                .log_err();
1338        });
1339
1340        self.detailed_summary_state = DetailedSummaryState::Generating {
1341            message_id: last_message_id,
1342        };
1343
1344        Some(task)
1345    }
1346
1347    pub fn is_generating_detailed_summary(&self) -> bool {
1348        matches!(
1349            self.detailed_summary_state,
1350            DetailedSummaryState::Generating { .. }
1351        )
1352    }
1353
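        /// Runs every idle pending tool use, first asking for confirmation when a tool requires it
        /// and the user hasn't enabled `always_allow_tool_actions`.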
1354    pub fn use_pending_tools(
1355        &mut self,
1356        cx: &mut Context<Self>,
1357    ) -> impl IntoIterator<Item = PendingToolUse> + use<> {
1358        let request = self.to_completion_request(RequestKind::Chat, cx);
1359        let messages = Arc::new(request.messages);
1360        let pending_tool_uses = self
1361            .tool_use
1362            .pending_tool_uses()
1363            .into_iter()
1364            .filter(|tool_use| tool_use.status.is_idle())
1365            .cloned()
1366            .collect::<Vec<_>>();
1367
1368        for tool_use in pending_tool_uses.iter() {
1369            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1370                if tool.needs_confirmation()
1371                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1372                {
1373                    self.tool_use.confirm_tool_use(
1374                        tool_use.id.clone(),
1375                        tool_use.ui_text.clone(),
1376                        tool_use.input.clone(),
1377                        messages.clone(),
1378                        tool,
1379                    );
1380                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1381                } else {
1382                    self.run_tool(
1383                        tool_use.id.clone(),
1384                        tool_use.ui_text.clone(),
1385                        tool_use.input.clone(),
1386                        &messages,
1387                        tool,
1388                        cx,
1389                    );
1390                }
1391            }
1392        }
1393
1394        pending_tool_uses
1395    }
1396
1397    pub fn run_tool(
1398        &mut self,
1399        tool_use_id: LanguageModelToolUseId,
1400        ui_text: impl Into<SharedString>,
1401        input: serde_json::Value,
1402        messages: &[LanguageModelRequestMessage],
1403        tool: Arc<dyn Tool>,
1404        cx: &mut Context<Thread>,
1405    ) {
1406        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1407        self.tool_use
1408            .run_pending_tool(tool_use_id, ui_text.into(), task);
1409    }
1410
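        /// Spawns a task that runs the tool (unless it is disabled) and records its output as a
        /// tool result, emitting [`ThreadEvent::ToolFinished`] when it completes.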
1411    fn spawn_tool_use(
1412        &mut self,
1413        tool_use_id: LanguageModelToolUseId,
1414        messages: &[LanguageModelRequestMessage],
1415        input: serde_json::Value,
1416        tool: Arc<dyn Tool>,
1417        cx: &mut Context<Thread>,
1418    ) -> Task<()> {
1419        let tool_name: Arc<str> = tool.name().into();
1420
1421        let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1422            Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1423        } else {
1424            tool.run(
1425                input,
1426                messages,
1427                self.project.clone(),
1428                self.action_log.clone(),
1429                cx,
1430            )
1431        };
1432
1433        cx.spawn({
1434            async move |thread: WeakEntity<Thread>, cx| {
1435                let output = run_tool.await;
1436
1437                thread
1438                    .update(cx, |thread, cx| {
1439                        let pending_tool_use = thread.tool_use.insert_tool_output(
1440                            tool_use_id.clone(),
1441                            tool_name,
1442                            output,
1443                        );
1444
1445                        cx.emit(ThreadEvent::ToolFinished {
1446                            tool_use_id,
1447                            pending_tool_use,
1448                            canceled: false,
1449                        });
1450                    })
1451                    .ok();
1452            }
1453        })
1454    }
1455
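        /// Records `updated_context` on the thread and inserts a follow-up user message
        /// so the completed tool results are carried back to the model on the next
        /// completion request.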
1456    pub fn attach_tool_results(
1457        &mut self,
1458        updated_context: Vec<AssistantContext>,
1459        cx: &mut Context<Self>,
1460    ) {
1461        self.context.extend(
1462            updated_context
1463                .into_iter()
1464                .map(|context| (context.id(), context)),
1465        );
1466
1467        // Insert a user message to contain the tool results.
1468        self.insert_user_message(
1469            // TODO: Sending up a user message without any content results in the model sending back
1470            // responses that also don't have any content. We currently don't handle this case well,
1471            // so for now we provide some text to keep the model on track.
1472            "Here are the tool results.",
1473            Vec::new(),
1474            None,
1475            cx,
1476        );
1477    }
1478
1479    /// Cancels the last pending completion, if any; otherwise cancels all pending tool uses.
1480    ///
1481    /// Returns whether a completion or any tool use was canceled.
1482    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1483        let canceled = if self.pending_completions.pop().is_some() {
1484            true
1485        } else {
1486            let mut canceled = false;
1487            for pending_tool_use in self.tool_use.cancel_pending() {
1488                canceled = true;
1489                cx.emit(ThreadEvent::ToolFinished {
1490                    tool_use_id: pending_tool_use.id.clone(),
1491                    pending_tool_use: Some(pending_tool_use),
1492                    canceled: true,
1493                });
1494            }
1495            canceled
1496        };
1497        self.finalize_pending_checkpoint(cx);
1498        canceled
1499    }
1500
1501    /// Returns the feedback given to the thread, if any.
1502    pub fn feedback(&self) -> Option<ThreadFeedback> {
1503        self.feedback
1504    }
1505
1506    /// Reports feedback about the thread and stores it in our telemetry backend.
1507    pub fn report_feedback(
1508        &mut self,
1509        feedback: ThreadFeedback,
1510        cx: &mut Context<Self>,
1511    ) -> Task<Result<()>> {
1512        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1513        let serialized_thread = self.serialize(cx);
1514        let thread_id = self.id().clone();
1515        let client = self.project.read(cx).client();
1516        self.feedback = Some(feedback);
1517        cx.notify();
1518
1519        cx.background_spawn(async move {
1520            let final_project_snapshot = final_project_snapshot.await;
1521            let serialized_thread = serialized_thread.await?;
1522            let thread_data =
1523                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1524
1525            let rating = match feedback {
1526                ThreadFeedback::Positive => "positive",
1527                ThreadFeedback::Negative => "negative",
1528            };
1529            telemetry::event!(
1530                "Assistant Thread Rated",
1531                rating,
1532                thread_id,
1533                thread_data,
1534                final_project_snapshot
1535            );
1536            client.telemetry().flush_events();
1537
1538            Ok(())
1539        })
1540    }
1541
1542    /// Creates a snapshot of the current project state, including git information and unsaved buffers.
1543    fn project_snapshot(
1544        project: Entity<Project>,
1545        cx: &mut Context<Self>,
1546    ) -> Task<Arc<ProjectSnapshot>> {
1547        let git_store = project.read(cx).git_store().clone();
1548        let worktree_snapshots: Vec<_> = project
1549            .read(cx)
1550            .visible_worktrees(cx)
1551            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1552            .collect();
1553
1554        cx.spawn(async move |_, cx| {
1555            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1556
1557            let mut unsaved_buffers = Vec::new();
1558            cx.update(|app_cx| {
1559                let buffer_store = project.read(app_cx).buffer_store();
1560                for buffer_handle in buffer_store.read(app_cx).buffers() {
1561                    let buffer = buffer_handle.read(app_cx);
1562                    if buffer.is_dirty() {
1563                        if let Some(file) = buffer.file() {
1564                            let path = file.path().to_string_lossy().to_string();
1565                            unsaved_buffers.push(path);
1566                        }
1567                    }
1568                }
1569            })
1570            .ok();
1571
1572            Arc::new(ProjectSnapshot {
1573                worktree_snapshots,
1574                unsaved_buffer_paths: unsaved_buffers,
1575                timestamp: Utc::now(),
1576            })
1577        })
1578    }
1579
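        /// Captures a [`WorktreeSnapshot`] for a single worktree: its absolute path and,
        /// when a local git repository covers it, the remote URL, HEAD SHA, current
        /// branch, and a diff of the worktree against HEAD.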
1580    fn worktree_snapshot(
1581        worktree: Entity<project::Worktree>,
1582        git_store: Entity<GitStore>,
1583        cx: &App,
1584    ) -> Task<WorktreeSnapshot> {
1585        cx.spawn(async move |cx| {
1586            // Get worktree path and snapshot
1587            let worktree_info = cx.update(|app_cx| {
1588                let worktree = worktree.read(app_cx);
1589                let path = worktree.abs_path().to_string_lossy().to_string();
1590                let snapshot = worktree.snapshot();
1591                (path, snapshot)
1592            });
1593
1594            let Ok((worktree_path, _snapshot)) = worktree_info else {
1595                return WorktreeSnapshot {
1596                    worktree_path: String::new(),
1597                    git_state: None,
1598                };
1599            };
1600
1601            let git_state = git_store
1602                .update(cx, |git_store, cx| {
1603                    git_store
1604                        .repositories()
1605                        .values()
1606                        .find(|repo| {
1607                            repo.read(cx)
1608                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1609                                .is_some()
1610                        })
1611                        .cloned()
1612                })
1613                .ok()
1614                .flatten()
1615                .map(|repo| {
1616                    repo.read_with(cx, |repo, _| {
1617                        let current_branch =
1618                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1619                        repo.send_job(|state, _| async move {
1620                            let RepositoryState::Local { backend, .. } = state else {
1621                                return GitState {
1622                                    remote_url: None,
1623                                    head_sha: None,
1624                                    current_branch,
1625                                    diff: None,
1626                                };
1627                            };
1628
1629                            let remote_url = backend.remote_url("origin");
1630                            let head_sha = backend.head_sha();
1631                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1632
1633                            GitState {
1634                                remote_url,
1635                                head_sha,
1636                                current_branch,
1637                                diff,
1638                            }
1639                        })
1640                    })
1641                });
1642
1643            let git_state = match git_state {
1644                Some(git_state) => match git_state.ok() {
1645                    Some(git_state) => git_state.await.ok(),
1646                    None => None,
1647                },
1648                None => None,
1649            };
1650
1651            WorktreeSnapshot {
1652                worktree_path,
1653                git_state,
1654            }
1655        })
1656    }
1657
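        /// Renders the thread as Markdown: an optional `#` summary heading, then a
        /// `## Role` section per message containing its text and `<think>` segments,
        /// any tool uses (with their JSON input in a fenced block), and any tool results.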
1658    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1659        let mut markdown = Vec::new();
1660
1661        if let Some(summary) = self.summary() {
1662            writeln!(markdown, "# {summary}\n")?;
1663        };
1664
1665        for message in self.messages() {
1666            writeln!(
1667                markdown,
1668                "## {role}\n",
1669                role = match message.role {
1670                    Role::User => "User",
1671                    Role::Assistant => "Assistant",
1672                    Role::System => "System",
1673                }
1674            )?;
1675            for segment in &message.segments {
1676                match segment {
1677                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1678                    MessageSegment::Thinking(text) => {
1679                        writeln!(markdown, "<think>{}</think>\n", text)?
1680                    }
1681                }
1682            }
1683
1684            for tool_use in self.tool_uses_for_message(message.id, cx) {
1685                writeln!(
1686                    markdown,
1687                    "**Use Tool: {} ({})**",
1688                    tool_use.name, tool_use.id
1689                )?;
1690                writeln!(markdown, "```json")?;
1691                writeln!(
1692                    markdown,
1693                    "{}",
1694                    serde_json::to_string_pretty(&tool_use.input)?
1695                )?;
1696                writeln!(markdown, "```")?;
1697            }
1698
1699            for tool_result in self.tool_results_for_message(message.id) {
1700                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1701                if tool_result.is_error {
1702                    write!(markdown, " (Error)")?;
1703                }
1704
1705                writeln!(markdown, "**\n")?;
1706                writeln!(markdown, "{}", tool_result.content)?;
1707            }
1708        }
1709
1710        Ok(String::from_utf8_lossy(&markdown).to_string())
1711    }
1712
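        /// Marks the edits tracked by the action log within `buffer_range` of `buffer`
        /// as kept (accepted).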
1713    pub fn keep_edits_in_range(
1714        &mut self,
1715        buffer: Entity<language::Buffer>,
1716        buffer_range: Range<language::Anchor>,
1717        cx: &mut Context<Self>,
1718    ) {
1719        self.action_log.update(cx, |action_log, cx| {
1720            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1721        });
1722    }
1723
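        /// Marks every edit tracked by the action log as kept (accepted).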
1724    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1725        self.action_log
1726            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1727    }
1728
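        /// Returns the action log used to track edits made while running tools.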
1729    pub fn action_log(&self) -> &Entity<ActionLog> {
1730        &self.action_log
1731    }
1732
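        /// Returns the project this thread operates on.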
1733    pub fn project(&self) -> &Entity<Project> {
1734        &self.project
1735    }
1736
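        /// Returns the token usage accumulated across all completions in this thread.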
1737    pub fn cumulative_token_usage(&self) -> TokenUsage {
1738        self.cumulative_token_usage.clone()
1739    }
1740
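        /// Computes the thread's cumulative token usage relative to the active model's
        /// maximum context size, classifying it as normal, warning (at or above 80% of
        /// the maximum by default), or exceeded.
        ///
        /// In debug builds the warning threshold can be overridden with the
        /// `ZED_THREAD_WARNING_THRESHOLD` environment variable, which must parse as an
        /// `f32`.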
1741    pub fn total_token_usage(&self, cx: &App) -> TotalTokenUsage {
1742        let model_registry = LanguageModelRegistry::read_global(cx);
1743        let Some(model) = model_registry.active_model() else {
1744            return TotalTokenUsage::default();
1745        };
1746
1747        let max = model.max_token_count();
1748
1749        #[cfg(debug_assertions)]
1750        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1751            .unwrap_or("0.8".to_string())
1752            .parse()
1753            .unwrap();
1754        #[cfg(not(debug_assertions))]
1755        let warning_threshold: f32 = 0.8;
1756
1757        let total = self.cumulative_token_usage.total_tokens() as usize;
1758
1759        let ratio = if total >= max {
1760            TokenUsageRatio::Exceeded
1761        } else if total as f32 / max as f32 >= warning_threshold {
1762            TokenUsageRatio::Warning
1763        } else {
1764            TokenUsageRatio::Normal
1765        };
1766
1767        TotalTokenUsage { total, max, ratio }
1768    }
1769
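        /// Rejects a pending tool use on the user's behalf, recording a
        /// permission-denied error as its output and emitting a canceled
        /// [`ThreadEvent::ToolFinished`] event.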
1770    pub fn deny_tool_use(
1771        &mut self,
1772        tool_use_id: LanguageModelToolUseId,
1773        tool_name: Arc<str>,
1774        cx: &mut Context<Self>,
1775    ) {
1776        let err = Err(anyhow::anyhow!(
1777            "Permission to run tool action denied by user"
1778        ));
1779
1780        self.tool_use
1781            .insert_tool_output(tool_use_id.clone(), tool_name, err);
1782
1783        cx.emit(ThreadEvent::ToolFinished {
1784            tool_use_id,
1785            pending_tool_use: None,
1786            canceled: true,
1787        });
1788    }
1789}
1790
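    /// Errors surfaced to the user while running a thread.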
1791#[derive(Debug, Clone)]
1792pub enum ThreadError {
1793    PaymentRequired,
1794    MaxMonthlySpendReached,
1795    Message {
1796        header: SharedString,
1797        message: SharedString,
1798    },
1799}
1800
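    /// Events emitted by a [`Thread`] as it streams completions and runs tools.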
1801#[derive(Debug, Clone)]
1802pub enum ThreadEvent {
1803    ShowError(ThreadError),
1804    StreamedCompletion,
1805    StreamedAssistantText(MessageId, String),
1806    StreamedAssistantThinking(MessageId, String),
1807    DoneStreaming,
1808    MessageAdded(MessageId),
1809    MessageEdited(MessageId),
1810    MessageDeleted(MessageId),
1811    SummaryChanged,
1812    UsePendingTools,
1813    ToolFinished {
1814        #[allow(unused)]
1815        tool_use_id: LanguageModelToolUseId,
1816        /// The pending tool use that corresponds to this tool result, if any.
1817        pending_tool_use: Option<PendingToolUse>,
1818        /// Whether the tool was canceled by the user.
1819        canceled: bool,
1820    },
1821    CheckpointChanged,
1822    ToolConfirmationNeeded,
1823}
1824
1825impl EventEmitter<ThreadEvent> for Thread {}
1826
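    /// An in-flight completion request; the task is held so the request keeps
    /// streaming until it finishes or the entry is dropped (see
    /// [`Thread::cancel_last_completion`]).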
1827struct PendingCompletion {
1828    id: usize,
1829    _task: Task<()>,
1830}