thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5
   6use anyhow::{Context as _, Result, anyhow};
   7use assistant_settings::AssistantSettings;
   8use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
   9use chrono::{DateTime, Utc};
  10use collections::{BTreeMap, HashMap, HashSet};
  11use fs::Fs;
  12use futures::future::Shared;
  13use futures::{FutureExt, StreamExt as _};
  14use git::repository::DiffType;
  15use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  16use language_model::{
  17    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
  18    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  19    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
  20    Role, StopReason, TokenUsage,
  21};
  22use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  23use project::{Project, Worktree};
  24use prompt_store::{
  25    AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt,
  26};
  27use schemars::JsonSchema;
  28use serde::{Deserialize, Serialize};
  29use settings::Settings;
  30use util::{ResultExt as _, TryFutureExt as _, maybe, post_inc};
  31use uuid::Uuid;
  32
  33use crate::context::{AssistantContext, ContextId, attach_context_to_message};
  34use crate::thread_store::{
  35    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  36    SerializedToolUse,
  37};
  38use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
  39
  40#[derive(Debug, Clone, Copy)]
  41pub enum RequestKind {
  42    Chat,
  43    /// Used when summarizing a thread.
  44    Summarize,
  45}
  46
  47#[derive(
  48    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  49)]
  50pub struct ThreadId(Arc<str>);
  51
  52impl ThreadId {
  53    pub fn new() -> Self {
  54        Self(Uuid::new_v4().to_string().into())
  55    }
  56}
  57
  58impl std::fmt::Display for ThreadId {
  59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  60        write!(f, "{}", self.0)
  61    }
  62}
  63
  64impl From<&str> for ThreadId {
  65    fn from(value: &str) -> Self {
  66        Self(value.into())
  67    }
  68}
  69
  70#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  71pub struct MessageId(pub(crate) usize);
  72
  73impl MessageId {
  74    fn post_inc(&mut self) -> Self {
  75        Self(post_inc(&mut self.0))
  76    }
  77}
  78
  79/// A message in a [`Thread`].
  80#[derive(Debug, Clone)]
  81pub struct Message {
  82    pub id: MessageId,
  83    pub role: Role,
  84    pub segments: Vec<MessageSegment>,
  85}
  86
  87impl Message {
  88    pub fn push_thinking(&mut self, text: &str) {
  89        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  90            segment.push_str(text);
  91        } else {
  92            self.segments
  93                .push(MessageSegment::Thinking(text.to_string()));
  94        }
  95    }
  96
  97    pub fn push_text(&mut self, text: &str) {
  98        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
  99            segment.push_str(text);
 100        } else {
 101            self.segments.push(MessageSegment::Text(text.to_string()));
 102        }
 103    }
 104
 105    pub fn to_string(&self) -> String {
 106        let mut result = String::new();
 107        for segment in &self.segments {
 108            match segment {
 109                MessageSegment::Text(text) => result.push_str(text),
 110                MessageSegment::Thinking(text) => {
 111                    result.push_str("<think>");
 112                    result.push_str(text);
 113                    result.push_str("</think>");
 114                }
 115            }
 116        }
 117        result
 118    }
 119}
 120
 121#[derive(Debug, Clone)]
 122pub enum MessageSegment {
 123    Text(String),
 124    Thinking(String),
 125}
 126
 127impl MessageSegment {
 128    pub fn text_mut(&mut self) -> &mut String {
 129        match self {
 130            Self::Text(text) => text,
 131            Self::Thinking(text) => text,
 132        }
 133    }
 134}
 135
 136#[derive(Debug, Clone, Serialize, Deserialize)]
 137pub struct ProjectSnapshot {
 138    pub worktree_snapshots: Vec<WorktreeSnapshot>,
 139    pub unsaved_buffer_paths: Vec<String>,
 140    pub timestamp: DateTime<Utc>,
 141}
 142
 143#[derive(Debug, Clone, Serialize, Deserialize)]
 144pub struct WorktreeSnapshot {
 145    pub worktree_path: String,
 146    pub git_state: Option<GitState>,
 147}
 148
 149#[derive(Debug, Clone, Serialize, Deserialize)]
 150pub struct GitState {
 151    pub remote_url: Option<String>,
 152    pub head_sha: Option<String>,
 153    pub current_branch: Option<String>,
 154    pub diff: Option<String>,
 155}
 156
 157#[derive(Clone)]
 158pub struct ThreadCheckpoint {
 159    message_id: MessageId,
 160    git_checkpoint: GitStoreCheckpoint,
 161}
 162
 163#[derive(Copy, Clone, Debug)]
 164pub enum ThreadFeedback {
 165    Positive,
 166    Negative,
 167}
 168
 169pub enum LastRestoreCheckpoint {
 170    Pending {
 171        message_id: MessageId,
 172    },
 173    Error {
 174        message_id: MessageId,
 175        error: String,
 176    },
 177}
 178
 179impl LastRestoreCheckpoint {
 180    pub fn message_id(&self) -> MessageId {
 181        match self {
 182            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 183            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 184        }
 185    }
 186}
 187
 188#[derive(Clone, Debug, Default, Serialize, Deserialize)]
 189pub enum DetailedSummaryState {
 190    #[default]
 191    NotGenerated,
 192    Generating {
 193        message_id: MessageId,
 194    },
 195    Generated {
 196        text: SharedString,
 197        message_id: MessageId,
 198    },
 199}
 200
/// A thread of conversation with the LLM.
///
/// Owns the message history plus its bookkeeping: attached context, git checkpoints,
/// tool-use state and in-flight completions.
pub struct Thread {
    /// Unique identifier of this thread.
    id: ThreadId,
    /// Last modification time; bumped via `touch_updated_at` when messages are inserted,
    /// edited or deleted.
    updated_at: DateTime<Utc>,
    /// Short summary (title) of the thread, if one has been set or loaded.
    summary: Option<SharedString>,
    /// Task tied to summary generation; starts as `Task::ready(None)`.
    pending_summary: Task<Option<()>>,
    /// Progress of the detailed summary; persisted by `serialize`.
    detailed_summary_state: DetailedSummaryState,
    /// Messages in insertion order.
    messages: Vec<Message>,
    /// Id the next inserted message will receive.
    next_message_id: MessageId,
    /// Every context item attached to any message, keyed by its id.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Ids of the context items attached to each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Inputs for the system prompt. When `None`, requests are built without a system
    /// message and an error is logged.
    system_prompt_context: Option<AssistantSystemPromptContext>,
    /// Restorable git checkpoints, keyed by the message they were taken for.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    /// Completion counter; starts at 0 (not updated in the visible part of this file).
    completion_count: usize,
    /// Completions still running; while non-empty the thread counts as generating.
    pending_completions: Vec<PendingCompletion>,
    /// Project this thread operates on.
    project: Entity<Project>,
    /// Renders the assistant system prompt.
    prompt_builder: Arc<PromptBuilder>,
    /// Tools that can be offered to the model.
    tools: Arc<ToolWorkingSet>,
    /// Tool uses and their results, tracked per message.
    tool_use: ToolUseState,
    /// Action log entity owned by this thread.
    action_log: Entity<ActionLog>,
    /// State of the latest checkpoint restore: `Pending` while running, `Error` on failure,
    /// `None` after success or when no restore happened.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint taken for the latest user message. `finalize_pending_checkpoint` keeps or
    /// discards it once generation is over.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot from when the thread was created (or as loaded from storage).
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Accumulated token usage; persisted by `serialize`.
    cumulative_token_usage: TokenUsage,
    /// Feedback recorded for this thread, if any.
    feedback: Option<ThreadFeedback>,
}
 227
 228impl Thread {
    /// Creates an empty thread for `project`.
    ///
    /// A snapshot of the project is started immediately and kept as a shared task, so the
    /// initial project state can be awaited later (e.g. by [`Thread::serialize`]).
    pub fn new(
        project: Entity<Project>,
        tools: Arc<ToolWorkingSet>,
        prompt_builder: Arc<PromptBuilder>,
        cx: &mut Context<Self>,
    ) -> Self {
        Self {
            id: ThreadId::new(),
            updated_at: Utc::now(),
            summary: None,
            pending_summary: Task::ready(None),
            detailed_summary_state: DetailedSummaryState::NotGenerated,
            messages: Vec::new(),
            next_message_id: MessageId(0),
            context: BTreeMap::default(),
            context_by_message: HashMap::default(),
            system_prompt_context: None,
            checkpoints_by_message: HashMap::default(),
            completion_count: 0,
            pending_completions: Vec::new(),
            // Cloned because `project` itself is moved into `project_snapshot` below.
            project: project.clone(),
            prompt_builder,
            tools: tools.clone(),
            last_restore_checkpoint: None,
            pending_checkpoint: None,
            tool_use: ToolUseState::new(tools.clone()),
            action_log: cx.new(|_| ActionLog::new()),
            initial_project_snapshot: {
                let project_snapshot = Self::project_snapshot(project, cx);
                // `shared()` lets several consumers await the same snapshot.
                cx.foreground_executor()
                    .spawn(async move { Some(project_snapshot.await) })
                    .shared()
            },
            cumulative_token_usage: TokenUsage::default(),
            feedback: None,
        }
    }
 266
 267    pub fn deserialize(
 268        id: ThreadId,
 269        serialized: SerializedThread,
 270        project: Entity<Project>,
 271        tools: Arc<ToolWorkingSet>,
 272        prompt_builder: Arc<PromptBuilder>,
 273        cx: &mut Context<Self>,
 274    ) -> Self {
 275        let next_message_id = MessageId(
 276            serialized
 277                .messages
 278                .last()
 279                .map(|message| message.id.0 + 1)
 280                .unwrap_or(0),
 281        );
 282        let tool_use =
 283            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 284
 285        Self {
 286            id,
 287            updated_at: serialized.updated_at,
 288            summary: Some(serialized.summary),
 289            pending_summary: Task::ready(None),
 290            detailed_summary_state: serialized.detailed_summary_state,
 291            messages: serialized
 292                .messages
 293                .into_iter()
 294                .map(|message| Message {
 295                    id: message.id,
 296                    role: message.role,
 297                    segments: message
 298                        .segments
 299                        .into_iter()
 300                        .map(|segment| match segment {
 301                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 302                            SerializedMessageSegment::Thinking { text } => {
 303                                MessageSegment::Thinking(text)
 304                            }
 305                        })
 306                        .collect(),
 307                })
 308                .collect(),
 309            next_message_id,
 310            context: BTreeMap::default(),
 311            context_by_message: HashMap::default(),
 312            system_prompt_context: None,
 313            checkpoints_by_message: HashMap::default(),
 314            completion_count: 0,
 315            pending_completions: Vec::new(),
 316            last_restore_checkpoint: None,
 317            pending_checkpoint: None,
 318            project,
 319            prompt_builder,
 320            tools,
 321            tool_use,
 322            action_log: cx.new(|_| ActionLog::new()),
 323            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 324            cumulative_token_usage: serialized.cumulative_token_usage,
 325            feedback: None,
 326        }
 327    }
 328
 329    pub fn id(&self) -> &ThreadId {
 330        &self.id
 331    }
 332
 333    pub fn is_empty(&self) -> bool {
 334        self.messages.is_empty()
 335    }
 336
 337    pub fn updated_at(&self) -> DateTime<Utc> {
 338        self.updated_at
 339    }
 340
 341    pub fn touch_updated_at(&mut self) {
 342        self.updated_at = Utc::now();
 343    }
 344
 345    pub fn summary(&self) -> Option<SharedString> {
 346        self.summary.clone()
 347    }
 348
 349    pub fn summary_or_default(&self) -> SharedString {
 350        const DEFAULT: SharedString = SharedString::new_static("New Thread");
 351        self.summary.clone().unwrap_or(DEFAULT)
 352    }
 353
 354    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 355        self.summary = Some(summary.into());
 356        cx.emit(ThreadEvent::SummaryChanged);
 357    }
 358
 359    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 360        self.latest_detailed_summary()
 361            .unwrap_or_else(|| self.text().into())
 362    }
 363
 364    fn latest_detailed_summary(&self) -> Option<SharedString> {
 365        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 366            Some(text.clone())
 367        } else {
 368            None
 369        }
 370    }
 371
 372    pub fn message(&self, id: MessageId) -> Option<&Message> {
 373        self.messages.iter().find(|message| message.id == id)
 374    }
 375
 376    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 377        self.messages.iter()
 378    }
 379
 380    pub fn is_generating(&self) -> bool {
 381        !self.pending_completions.is_empty() || !self.all_tools_finished()
 382    }
 383
 384    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 385        &self.tools
 386    }
 387
 388    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 389        self.tool_use
 390            .pending_tool_uses()
 391            .into_iter()
 392            .find(|tool_use| &tool_use.id == id)
 393    }
 394
 395    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 396        self.tool_use
 397            .pending_tool_uses()
 398            .into_iter()
 399            .filter(|tool_use| tool_use.status.needs_confirmation())
 400    }
 401
 402    pub fn has_pending_tool_uses(&self) -> bool {
 403        !self.tool_use.pending_tool_uses().is_empty()
 404    }
 405
 406    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 407        self.checkpoints_by_message.get(&id).cloned()
 408    }
 409
    /// Restores the project's git state to `checkpoint`.
    ///
    /// On success, the thread is truncated so that the checkpoint's message and every
    /// message after it are removed. While the restore runs,
    /// [`Thread::last_restore_checkpoint`] reports `Pending`. On failure it reports `Error`
    /// with the error text. In both outcomes the pending checkpoint is cleared and
    /// `CheckpointChanged` is emitted. The returned task resolves to the restore result.
    pub fn restore_checkpoint(
        &mut self,
        checkpoint: ThreadCheckpoint,
        cx: &mut Context<Self>,
    ) -> Task<Result<()>> {
        // Publish the in-progress state before starting the restore.
        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
            message_id: checkpoint.message_id,
        });
        cx.emit(ThreadEvent::CheckpointChanged);
        cx.notify();

        let project = self.project.read(cx);
        let restore = project
            .git_store()
            .read(cx)
            .restore_checkpoint(checkpoint.git_checkpoint.clone(), cx);
        cx.spawn(async move |this, cx| {
            let result = restore.await;
            this.update(cx, |this, cx| {
                if let Err(err) = result.as_ref() {
                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
                        message_id: checkpoint.message_id,
                        error: err.to_string(),
                    });
                } else {
                    // Git state is back at the checkpoint; drop the messages that came after it.
                    this.truncate(checkpoint.message_id, cx);
                    this.last_restore_checkpoint = None;
                }
                this.pending_checkpoint = None;
                cx.emit(ThreadEvent::CheckpointChanged);
                cx.notify();
            })?;
            result
        })
    }
 445
    /// Once generation has finished, decides whether the pending checkpoint is worth keeping.
    ///
    /// Does nothing while the thread is still generating (the pending checkpoint stays in
    /// place) or when there is no pending checkpoint. Otherwise a fresh "final" checkpoint
    /// is taken and compared with the pending one:
    /// - equal: nothing changed, so the pending checkpoint is deleted;
    /// - different, or the comparison itself failed (treated as "not equal"): the pending
    ///   checkpoint is recorded for its message via `insert_checkpoint`;
    /// - taking the final checkpoint failed: the pending checkpoint is recorded as well.
    ///
    /// On the success path, the temporary final checkpoint is deleted afterwards. This is
    /// skipped if an earlier `?` bails out because the store or thread is gone. All of this
    /// runs in a detached task.
    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
        let pending_checkpoint = if self.is_generating() {
            return;
        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
            checkpoint
        } else {
            return;
        };

        let git_store = self.project.read(cx).git_store().clone();
        let final_checkpoint = git_store.read(cx).checkpoint(cx);
        cx.spawn(async move |this, cx| match final_checkpoint.await {
            Ok(final_checkpoint) => {
                // A failed comparison counts as "changed" so the checkpoint is kept.
                let equal = git_store
                    .read_with(cx, |store, cx| {
                        store.compare_checkpoints(
                            pending_checkpoint.git_checkpoint.clone(),
                            final_checkpoint.clone(),
                            cx,
                        )
                    })?
                    .await
                    .unwrap_or(false);

                if equal {
                    git_store
                        .read_with(cx, |store, cx| {
                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
                        })?
                        .detach();
                } else {
                    this.update(cx, |this, cx| {
                        this.insert_checkpoint(pending_checkpoint, cx)
                    })?;
                }

                // The final checkpoint only served as a comparison baseline.
                git_store
                    .read_with(cx, |store, cx| {
                        store.delete_checkpoint(final_checkpoint, cx)
                    })?
                    .detach();

                Ok(())
            }
            Err(_) => this.update(cx, |this, cx| {
                this.insert_checkpoint(pending_checkpoint, cx)
            }),
        })
        .detach();
    }
 496
 497    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 498        self.checkpoints_by_message
 499            .insert(checkpoint.message_id, checkpoint);
 500        cx.emit(ThreadEvent::CheckpointChanged);
 501        cx.notify();
 502    }
 503
 504    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 505        self.last_restore_checkpoint.as_ref()
 506    }
 507
 508    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 509        let Some(message_ix) = self
 510            .messages
 511            .iter()
 512            .rposition(|message| message.id == message_id)
 513        else {
 514            return;
 515        };
 516        for deleted_message in self.messages.drain(message_ix..) {
 517            self.context_by_message.remove(&deleted_message.id);
 518            self.checkpoints_by_message.remove(&deleted_message.id);
 519        }
 520        cx.notify();
 521    }
 522
 523    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 524        self.context_by_message
 525            .get(&id)
 526            .into_iter()
 527            .flat_map(|context| {
 528                context
 529                    .iter()
 530                    .filter_map(|context_id| self.context.get(&context_id))
 531            })
 532    }
 533
 534    /// Returns whether all of the tool uses have finished running.
 535    pub fn all_tools_finished(&self) -> bool {
 536        // If the only pending tool uses left are the ones with errors, then
 537        // that means that we've finished running all of the pending tools.
 538        self.tool_use
 539            .pending_tool_uses()
 540            .iter()
 541            .all(|tool_use| tool_use.status.is_error())
 542    }
 543
 544    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 545        self.tool_use.tool_uses_for_message(id, cx)
 546    }
 547
 548    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 549        self.tool_use.tool_results_for_message(id)
 550    }
 551
 552    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 553        self.tool_use.tool_result(id)
 554    }
 555
 556    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 557        self.tool_use.message_has_tool_results(message_id)
 558    }
 559
 560    pub fn insert_user_message(
 561        &mut self,
 562        text: impl Into<String>,
 563        context: Vec<AssistantContext>,
 564        git_checkpoint: Option<GitStoreCheckpoint>,
 565        cx: &mut Context<Self>,
 566    ) -> MessageId {
 567        let message_id =
 568            self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
 569        let context_ids = context
 570            .iter()
 571            .map(|context| context.id())
 572            .collect::<Vec<_>>();
 573        self.context
 574            .extend(context.into_iter().map(|context| (context.id(), context)));
 575        self.context_by_message.insert(message_id, context_ids);
 576        if let Some(git_checkpoint) = git_checkpoint {
 577            self.pending_checkpoint = Some(ThreadCheckpoint {
 578                message_id,
 579                git_checkpoint,
 580            });
 581        }
 582        message_id
 583    }
 584
 585    pub fn insert_message(
 586        &mut self,
 587        role: Role,
 588        segments: Vec<MessageSegment>,
 589        cx: &mut Context<Self>,
 590    ) -> MessageId {
 591        let id = self.next_message_id.post_inc();
 592        self.messages.push(Message { id, role, segments });
 593        self.touch_updated_at();
 594        cx.emit(ThreadEvent::MessageAdded(id));
 595        id
 596    }
 597
 598    pub fn edit_message(
 599        &mut self,
 600        id: MessageId,
 601        new_role: Role,
 602        new_segments: Vec<MessageSegment>,
 603        cx: &mut Context<Self>,
 604    ) -> bool {
 605        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 606            return false;
 607        };
 608        message.role = new_role;
 609        message.segments = new_segments;
 610        self.touch_updated_at();
 611        cx.emit(ThreadEvent::MessageEdited(id));
 612        true
 613    }
 614
 615    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 616        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 617            return false;
 618        };
 619        self.messages.remove(index);
 620        self.context_by_message.remove(&id);
 621        self.touch_updated_at();
 622        cx.emit(ThreadEvent::MessageDeleted(id));
 623        true
 624    }
 625
 626    /// Returns the representation of this [`Thread`] in a textual form.
 627    ///
 628    /// This is the representation we use when attaching a thread as context to another thread.
 629    pub fn text(&self) -> String {
 630        let mut text = String::new();
 631
 632        for message in &self.messages {
 633            text.push_str(match message.role {
 634                language_model::Role::User => "User:",
 635                language_model::Role::Assistant => "Assistant:",
 636                language_model::Role::System => "System:",
 637            });
 638            text.push('\n');
 639
 640            for segment in &message.segments {
 641                match segment {
 642                    MessageSegment::Text(content) => text.push_str(content),
 643                    MessageSegment::Thinking(content) => {
 644                        text.push_str(&format!("<think>{}</think>", content))
 645                    }
 646                }
 647            }
 648            text.push('\n');
 649        }
 650
 651        text
 652    }
 653
    /// Serializes this thread into a format for storage or telemetry.
    ///
    /// Waits for the initial project snapshot, then captures the following:
    /// - messages, with their segments, tool uses and tool results;
    /// - the summary, or the default title;
    /// - the update timestamp;
    /// - cumulative token usage and the detailed summary state.
    ///
    /// Attached context, checkpoints and feedback are not included.
    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
        let initial_project_snapshot = self.initial_project_snapshot.clone();
        cx.spawn(async move |this, cx| {
            let initial_project_snapshot = initial_project_snapshot.await;
            this.read_with(cx, |this, cx| SerializedThread {
                version: SerializedThread::VERSION.to_string(),
                summary: this.summary_or_default(),
                updated_at: this.updated_at(),
                messages: this
                    .messages()
                    .map(|message| SerializedMessage {
                        id: message.id,
                        role: message.role,
                        segments: message
                            .segments
                            .iter()
                            .map(|segment| match segment {
                                MessageSegment::Text(text) => {
                                    SerializedMessageSegment::Text { text: text.clone() }
                                }
                                MessageSegment::Thinking(text) => {
                                    SerializedMessageSegment::Thinking { text: text.clone() }
                                }
                            })
                            .collect(),
                        tool_uses: this
                            .tool_uses_for_message(message.id, cx)
                            .into_iter()
                            .map(|tool_use| SerializedToolUse {
                                id: tool_use.id,
                                name: tool_use.name,
                                input: tool_use.input,
                            })
                            .collect(),
                        tool_results: this
                            .tool_results_for_message(message.id)
                            .into_iter()
                            .map(|tool_result| SerializedToolResult {
                                tool_use_id: tool_result.tool_use_id.clone(),
                                is_error: tool_result.is_error,
                                content: tool_result.content.clone(),
                            })
                            .collect(),
                    })
                    .collect(),
                initial_project_snapshot,
                cumulative_token_usage: this.cumulative_token_usage.clone(),
                detailed_summary_state: this.detailed_summary_state.clone(),
            })
        })
    }
 706
 707    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
 708        self.system_prompt_context = Some(context);
 709    }
 710
 711    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
 712        &self.system_prompt_context
 713    }
 714
 715    pub fn load_system_prompt_context(
 716        &self,
 717        cx: &App,
 718    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
 719        let project = self.project.read(cx);
 720        let tasks = project
 721            .visible_worktrees(cx)
 722            .map(|worktree| {
 723                Self::load_worktree_info_for_system_prompt(
 724                    project.fs().clone(),
 725                    worktree.read(cx),
 726                    cx,
 727                )
 728            })
 729            .collect::<Vec<_>>();
 730
 731        cx.spawn(async |_cx| {
 732            let results = futures::future::join_all(tasks).await;
 733            let mut first_err = None;
 734            let worktrees = results
 735                .into_iter()
 736                .map(|(worktree, err)| {
 737                    if first_err.is_none() && err.is_some() {
 738                        first_err = err;
 739                    }
 740                    worktree
 741                })
 742                .collect::<Vec<_>>();
 743            (AssistantSystemPromptContext::new(worktrees), first_err)
 744        })
 745    }
 746
 747    fn load_worktree_info_for_system_prompt(
 748        fs: Arc<dyn Fs>,
 749        worktree: &Worktree,
 750        cx: &App,
 751    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
 752        let root_name = worktree.root_name().into();
 753        let abs_path = worktree.abs_path();
 754
 755        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 756        // supported. This doesn't seem to occur often in GitHub repositories.
 757        const RULES_FILE_NAMES: [&'static str; 6] = [
 758            ".rules",
 759            ".cursorrules",
 760            ".windsurfrules",
 761            ".clinerules",
 762            ".github/copilot-instructions.md",
 763            "CLAUDE.md",
 764        ];
 765        let selected_rules_file = RULES_FILE_NAMES
 766            .into_iter()
 767            .filter_map(|name| {
 768                worktree
 769                    .entry_for_path(name)
 770                    .filter(|entry| entry.is_file())
 771                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
 772            })
 773            .next();
 774
 775        if let Some((rel_rules_path, abs_rules_path)) = selected_rules_file {
 776            cx.spawn(async move |_| {
 777                let rules_file_result = maybe!(async move {
 778                    let abs_rules_path = abs_rules_path?;
 779                    let text = fs.load(&abs_rules_path).await.with_context(|| {
 780                        format!("Failed to load assistant rules file {:?}", abs_rules_path)
 781                    })?;
 782                    anyhow::Ok(RulesFile {
 783                        rel_path: rel_rules_path,
 784                        abs_path: abs_rules_path.into(),
 785                        text: text.trim().to_string(),
 786                    })
 787                })
 788                .await;
 789                let (rules_file, rules_file_error) = match rules_file_result {
 790                    Ok(rules_file) => (Some(rules_file), None),
 791                    Err(err) => (
 792                        None,
 793                        Some(ThreadError::Message {
 794                            header: "Error loading rules file".into(),
 795                            message: format!("{err}").into(),
 796                        }),
 797                    ),
 798                };
 799                let worktree_info = WorktreeInfoForSystemPrompt {
 800                    root_name,
 801                    abs_path,
 802                    rules_file,
 803                };
 804                (worktree_info, rules_file_error)
 805            })
 806        } else {
 807            Task::ready((
 808                WorktreeInfoForSystemPrompt {
 809                    root_name,
 810                    abs_path,
 811                    rules_file: None,
 812                },
 813                None,
 814            ))
 815        }
 816    }
 817
 818    pub fn send_to_model(
 819        &mut self,
 820        model: Arc<dyn LanguageModel>,
 821        request_kind: RequestKind,
 822        cx: &mut Context<Self>,
 823    ) {
 824        let mut request = self.to_completion_request(request_kind, cx);
 825        if model.supports_tools() {
 826            request.tools = {
 827                let mut tools = Vec::new();
 828                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 829                    LanguageModelRequestTool {
 830                        name: tool.name(),
 831                        description: tool.description(),
 832                        input_schema: tool.input_schema(model.tool_input_format()),
 833                    }
 834                }));
 835
 836                tools
 837            };
 838        }
 839
 840        self.stream_completion(request, model, cx);
 841    }
 842
 843    pub fn used_tools_since_last_user_message(&self) -> bool {
 844        for message in self.messages.iter().rev() {
 845            if self.tool_use.message_has_tool_results(message.id) {
 846                return true;
 847            } else if message.role == Role::User {
 848                return false;
 849            }
 850        }
 851
 852        false
 853    }
 854
 855    pub fn to_completion_request(
 856        &self,
 857        request_kind: RequestKind,
 858        cx: &App,
 859    ) -> LanguageModelRequest {
 860        let mut request = LanguageModelRequest {
 861            messages: vec![],
 862            tools: Vec::new(),
 863            stop: Vec::new(),
 864            temperature: None,
 865        };
 866
 867        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
 868            if let Some(system_prompt) = self
 869                .prompt_builder
 870                .generate_assistant_system_prompt(system_prompt_context)
 871                .context("failed to generate assistant system prompt")
 872                .log_err()
 873            {
 874                request.messages.push(LanguageModelRequestMessage {
 875                    role: Role::System,
 876                    content: vec![MessageContent::Text(system_prompt)],
 877                    cache: true,
 878                });
 879            }
 880        } else {
 881            log::error!("system_prompt_context not set.")
 882        }
 883
 884        let mut added_context_ids = HashSet::<ContextId>::default();
 885
 886        for message in &self.messages {
 887            let mut request_message = LanguageModelRequestMessage {
 888                role: message.role,
 889                content: Vec::new(),
 890                cache: false,
 891            };
 892
 893            match request_kind {
 894                RequestKind::Chat => {
 895                    self.tool_use
 896                        .attach_tool_results(message.id, &mut request_message);
 897                }
 898                RequestKind::Summarize => {
 899                    // We don't care about tool use during summarization.
 900                    if self.tool_use.message_has_tool_results(message.id) {
 901                        continue;
 902                    }
 903                }
 904            }
 905
 906            // Attach context to this message if it's the first to reference it
 907            if let Some(context_ids) = self.context_by_message.get(&message.id) {
 908                let new_context_ids: Vec<_> = context_ids
 909                    .iter()
 910                    .filter(|id| !added_context_ids.contains(id))
 911                    .collect();
 912
 913                if !new_context_ids.is_empty() {
 914                    let referenced_context = new_context_ids
 915                        .iter()
 916                        .filter_map(|context_id| self.context.get(*context_id));
 917
 918                    attach_context_to_message(&mut request_message, referenced_context, cx);
 919                    added_context_ids.extend(context_ids.iter());
 920                }
 921            }
 922
 923            if !message.segments.is_empty() {
 924                request_message
 925                    .content
 926                    .push(MessageContent::Text(message.to_string()));
 927            }
 928
 929            match request_kind {
 930                RequestKind::Chat => {
 931                    self.tool_use
 932                        .attach_tool_uses(message.id, &mut request_message);
 933                }
 934                RequestKind::Summarize => {
 935                    // We don't care about tool use during summarization.
 936                }
 937            };
 938
 939            request.messages.push(request_message);
 940        }
 941
 942        // Set a cache breakpoint at the second-to-last message.
 943        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 944        let breakpoint_index = request.messages.len() - 2;
 945        for (index, message) in request.messages.iter_mut().enumerate() {
 946            message.cache = index == breakpoint_index;
 947        }
 948
 949        self.attached_tracked_files_state(&mut request.messages, cx);
 950
 951        request
 952    }
 953
 954    fn attached_tracked_files_state(
 955        &self,
 956        messages: &mut Vec<LanguageModelRequestMessage>,
 957        cx: &App,
 958    ) {
 959        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 960
 961        let mut stale_message = String::new();
 962
 963        let action_log = self.action_log.read(cx);
 964
 965        for stale_file in action_log.stale_buffers(cx) {
 966            let Some(file) = stale_file.read(cx).file() else {
 967                continue;
 968            };
 969
 970            if stale_message.is_empty() {
 971                write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
 972            }
 973
 974            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 975        }
 976
 977        let mut content = Vec::with_capacity(2);
 978
 979        if !stale_message.is_empty() {
 980            content.push(stale_message.into());
 981        }
 982
 983        if action_log.has_edited_files_since_project_diagnostics_check() {
 984            content.push(
 985                "\n\nWhen you're done making changes, make sure to check project diagnostics \
 986                and fix all errors AND warnings you introduced! \
 987                DO NOT mention you're going to do this until you're done."
 988                    .into(),
 989            );
 990        }
 991
 992        if !content.is_empty() {
 993            let context_message = LanguageModelRequestMessage {
 994                role: Role::User,
 995                content,
 996                cache: false,
 997            };
 998
 999            messages.push(context_message);
1000        }
1001    }
1002
    /// Streams a completion for `request` from `model` into this thread.
    ///
    /// The spawned task is stored in `pending_completions` under a fresh id taken
    /// from `completion_count`. While the stream runs, each event updates the
    /// thread:
    /// - `StartMessage` inserts an empty assistant message;
    /// - `Text` / `Thinking` chunks are appended to the last message when it is an
    ///   assistant message, or start a new assistant message otherwise;
    /// - `ToolUse` is registered with `tool_use` against the most recent
    ///   assistant message;
    /// - `UsageUpdate` adjusts `cumulative_token_usage`;
    /// - `Stop` records the stop reason.
    ///
    /// After every event the thread emits `StreamedCompletion` and notifies.
    ///
    /// On success the completion is removed from `pending_completions`, and a
    /// title summary is requested if none exists yet and there are at least two
    /// messages. Afterwards the thread:
    /// - finalizes any pending checkpoint;
    /// - emits `UsePendingTools` if the model stopped for tool use, or
    ///   `ShowError` (followed by `cancel_last_completion`) on failure;
    /// - emits `DoneStreaming`;
    /// - records token-usage telemetry for this completion.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            // Snapshot of usage before this completion, used for the telemetry delta below.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                // Usage reported so far by *this* completion.
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Each update replaces the previous one for this completion:
                                // the previously applied usage is subtracted before adding
                                // the new value, i.e. updates are treated as running totals.
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                // NOTE(review): if the thread has no messages at all, the chunk
                                // is dropped because there is no `else` branch here.
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If the last message isn't an Assistant message, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If the last message isn't an Assistant message, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Attach the tool use to the most recent assistant message,
                                // giving it placeholder text when its first segment is empty
                                // (or it has none); create one if no assistant message exists.
                                // NOTE(review): if the first segment is a thinking segment,
                                // `text_mut` presumably yields its text, so the placeholder
                                // would land there — confirm.
                                let last_assistant_message = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant);

                                let last_assistant_message_id =
                                    if let Some(message) = last_assistant_message {
                                        if let Some(segment) = message.segments.first_mut() {
                                            let text = segment.text_mut();
                                            if text.is_empty() {
                                                text.push_str("Using tool...");
                                            }
                                        } else {
                                            message.segments.push(MessageSegment::Text(
                                                "Using tool...".to_string(),
                                            ));
                                        }

                                        message.id
                                    } else {
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text("Using tool...".to_string())],
                                            cx,
                                        )
                                    };
                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();
                    })?;

                    // Give other tasks a chance to run between events.
                    smol::future::yield_now().await;
                }

                // Only reached when the stream ended without error; on error the
                // `?` above skips this cleanup.
                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                cx.emit(ThreadEvent::UsePendingTools);
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else {
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            // NOTE(review): this pops the *last* pending completion, which is
                            // this one only if no other completion was started after it —
                            // confirm that overlapping completions cannot occur.
                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::DoneStreaming);

                    if let Ok(initial_usage) = initial_token_usage {
                        // Usage attributable to this completion alone.
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1201
1202    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1203        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
1204            return;
1205        };
1206        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
1207            return;
1208        };
1209
1210        if !provider.is_authenticated(cx) {
1211            return;
1212        }
1213
1214        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1215        request.messages.push(LanguageModelRequestMessage {
1216            role: Role::User,
1217            content: vec![
1218                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1219                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1220                 If the conversation is about a specific subject, include it in the title. \
1221                 Be descriptive. DO NOT speak in the first person."
1222                    .into(),
1223            ],
1224            cache: false,
1225        });
1226
1227        self.pending_summary = cx.spawn(async move |this, cx| {
1228            async move {
1229                let stream = model.stream_completion_text(request, &cx);
1230                let mut messages = stream.await?;
1231
1232                let mut new_summary = String::new();
1233                while let Some(message) = messages.stream.next().await {
1234                    let text = message?;
1235                    let mut lines = text.lines();
1236                    new_summary.extend(lines.next());
1237
1238                    // Stop if the LLM generated multiple lines.
1239                    if lines.next().is_some() {
1240                        break;
1241                    }
1242                }
1243
1244                this.update(cx, |this, cx| {
1245                    if !new_summary.is_empty() {
1246                        this.summary = Some(new_summary.into());
1247                    }
1248
1249                    cx.emit(ThreadEvent::SummaryChanged);
1250                })?;
1251
1252                anyhow::Ok(())
1253            }
1254            .log_err()
1255            .await
1256        });
1257    }
1258
1259    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1260        let last_message_id = self.messages.last().map(|message| message.id)?;
1261
1262        match &self.detailed_summary_state {
1263            DetailedSummaryState::Generating { message_id, .. }
1264            | DetailedSummaryState::Generated { message_id, .. }
1265                if *message_id == last_message_id =>
1266            {
1267                // Already up-to-date
1268                return None;
1269            }
1270            _ => {}
1271        }
1272
1273        let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
1274        let model = LanguageModelRegistry::read_global(cx).active_model()?;
1275
1276        if !provider.is_authenticated(cx) {
1277            return None;
1278        }
1279
1280        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1281
1282        request.messages.push(LanguageModelRequestMessage {
1283            role: Role::User,
1284            content: vec![
1285                "Generate a detailed summary of this conversation. Include:\n\
1286                1. A brief overview of what was discussed\n\
1287                2. Key facts or information discovered\n\
1288                3. Outcomes or conclusions reached\n\
1289                4. Any action items or next steps if any\n\
1290                Format it in Markdown with headings and bullet points."
1291                    .into(),
1292            ],
1293            cache: false,
1294        });
1295
1296        let task = cx.spawn(async move |thread, cx| {
1297            let stream = model.stream_completion_text(request, &cx);
1298            let Some(mut messages) = stream.await.log_err() else {
1299                thread
1300                    .update(cx, |this, _cx| {
1301                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1302                    })
1303                    .log_err();
1304
1305                return;
1306            };
1307
1308            let mut new_detailed_summary = String::new();
1309
1310            while let Some(chunk) = messages.stream.next().await {
1311                if let Some(chunk) = chunk.log_err() {
1312                    new_detailed_summary.push_str(&chunk);
1313                }
1314            }
1315
1316            thread
1317                .update(cx, |this, _cx| {
1318                    this.detailed_summary_state = DetailedSummaryState::Generated {
1319                        text: new_detailed_summary.into(),
1320                        message_id: last_message_id,
1321                    };
1322                })
1323                .log_err();
1324        });
1325
1326        self.detailed_summary_state = DetailedSummaryState::Generating {
1327            message_id: last_message_id,
1328        };
1329
1330        Some(task)
1331    }
1332
1333    pub fn is_generating_detailed_summary(&self) -> bool {
1334        matches!(
1335            self.detailed_summary_state,
1336            DetailedSummaryState::Generating { .. }
1337        )
1338    }
1339
1340    pub fn use_pending_tools(
1341        &mut self,
1342        cx: &mut Context<Self>,
1343    ) -> impl IntoIterator<Item = PendingToolUse> + use<> {
1344        let request = self.to_completion_request(RequestKind::Chat, cx);
1345        let messages = Arc::new(request.messages);
1346        let pending_tool_uses = self
1347            .tool_use
1348            .pending_tool_uses()
1349            .into_iter()
1350            .filter(|tool_use| tool_use.status.is_idle())
1351            .cloned()
1352            .collect::<Vec<_>>();
1353
1354        for tool_use in pending_tool_uses.iter() {
1355            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1356                if tool.needs_confirmation()
1357                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1358                {
1359                    self.tool_use.confirm_tool_use(
1360                        tool_use.id.clone(),
1361                        tool_use.ui_text.clone(),
1362                        tool_use.input.clone(),
1363                        messages.clone(),
1364                        tool,
1365                    );
1366                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1367                } else {
1368                    self.run_tool(
1369                        tool_use.id.clone(),
1370                        tool_use.ui_text.clone(),
1371                        tool_use.input.clone(),
1372                        &messages,
1373                        tool,
1374                        cx,
1375                    );
1376                }
1377            }
1378        }
1379
1380        pending_tool_uses
1381    }
1382
1383    pub fn run_tool(
1384        &mut self,
1385        tool_use_id: LanguageModelToolUseId,
1386        ui_text: impl Into<SharedString>,
1387        input: serde_json::Value,
1388        messages: &[LanguageModelRequestMessage],
1389        tool: Arc<dyn Tool>,
1390        cx: &mut Context<Thread>,
1391    ) {
1392        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1393        self.tool_use
1394            .run_pending_tool(tool_use_id, ui_text.into(), task);
1395    }
1396
1397    fn spawn_tool_use(
1398        &mut self,
1399        tool_use_id: LanguageModelToolUseId,
1400        messages: &[LanguageModelRequestMessage],
1401        input: serde_json::Value,
1402        tool: Arc<dyn Tool>,
1403        cx: &mut Context<Thread>,
1404    ) -> Task<()> {
1405        let tool_name: Arc<str> = tool.name().into();
1406
1407        let run_tool = if self.tools.is_disabled(&tool.source(), &tool_name) {
1408            Task::ready(Err(anyhow!("tool is disabled: {tool_name}")))
1409        } else {
1410            tool.run(
1411                input,
1412                messages,
1413                self.project.clone(),
1414                self.action_log.clone(),
1415                cx,
1416            )
1417        };
1418
1419        cx.spawn({
1420            async move |thread: WeakEntity<Thread>, cx| {
1421                let output = run_tool.await;
1422
1423                thread
1424                    .update(cx, |thread, cx| {
1425                        let pending_tool_use = thread.tool_use.insert_tool_output(
1426                            tool_use_id.clone(),
1427                            tool_name,
1428                            output,
1429                        );
1430
1431                        cx.emit(ThreadEvent::ToolFinished {
1432                            tool_use_id,
1433                            pending_tool_use,
1434                            canceled: false,
1435                        });
1436                    })
1437                    .ok();
1438            }
1439        })
1440    }
1441
1442    pub fn attach_tool_results(
1443        &mut self,
1444        updated_context: Vec<AssistantContext>,
1445        cx: &mut Context<Self>,
1446    ) {
1447        self.context.extend(
1448            updated_context
1449                .into_iter()
1450                .map(|context| (context.id(), context)),
1451        );
1452
1453        // Insert a user message to contain the tool results.
1454        self.insert_user_message(
1455            // TODO: Sending up a user message without any content results in the model sending back
1456            // responses that also don't have any content. We currently don't handle this case well,
1457            // so for now we provide some text to keep the model on track.
1458            "Here are the tool results.",
1459            Vec::new(),
1460            None,
1461            cx,
1462        );
1463    }
1464
1465    /// Cancels the last pending completion, if there are any pending.
1466    ///
1467    /// Returns whether a completion was canceled.
1468    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1469        let canceled = if self.pending_completions.pop().is_some() {
1470            true
1471        } else {
1472            let mut canceled = false;
1473            for pending_tool_use in self.tool_use.cancel_pending() {
1474                canceled = true;
1475                cx.emit(ThreadEvent::ToolFinished {
1476                    tool_use_id: pending_tool_use.id.clone(),
1477                    pending_tool_use: Some(pending_tool_use),
1478                    canceled: true,
1479                });
1480            }
1481            canceled
1482        };
1483        self.finalize_pending_checkpoint(cx);
1484        canceled
1485    }
1486
1487    /// Returns the feedback given to the thread, if any.
1488    pub fn feedback(&self) -> Option<ThreadFeedback> {
1489        self.feedback
1490    }
1491
1492    /// Reports feedback about the thread and stores it in our telemetry backend.
1493    pub fn report_feedback(
1494        &mut self,
1495        feedback: ThreadFeedback,
1496        cx: &mut Context<Self>,
1497    ) -> Task<Result<()>> {
1498        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1499        let serialized_thread = self.serialize(cx);
1500        let thread_id = self.id().clone();
1501        let client = self.project.read(cx).client();
1502        self.feedback = Some(feedback);
1503        cx.notify();
1504
1505        cx.background_spawn(async move {
1506            let final_project_snapshot = final_project_snapshot.await;
1507            let serialized_thread = serialized_thread.await?;
1508            let thread_data =
1509                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1510
1511            let rating = match feedback {
1512                ThreadFeedback::Positive => "positive",
1513                ThreadFeedback::Negative => "negative",
1514            };
1515            telemetry::event!(
1516                "Assistant Thread Rated",
1517                rating,
1518                thread_id,
1519                thread_data,
1520                final_project_snapshot
1521            );
1522            client.telemetry().flush_events();
1523
1524            Ok(())
1525        })
1526    }
1527
1528    /// Create a snapshot of the current project state including git information and unsaved buffers.
1529    fn project_snapshot(
1530        project: Entity<Project>,
1531        cx: &mut Context<Self>,
1532    ) -> Task<Arc<ProjectSnapshot>> {
1533        let git_store = project.read(cx).git_store().clone();
1534        let worktree_snapshots: Vec<_> = project
1535            .read(cx)
1536            .visible_worktrees(cx)
1537            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1538            .collect();
1539
1540        cx.spawn(async move |_, cx| {
1541            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1542
1543            let mut unsaved_buffers = Vec::new();
1544            cx.update(|app_cx| {
1545                let buffer_store = project.read(app_cx).buffer_store();
1546                for buffer_handle in buffer_store.read(app_cx).buffers() {
1547                    let buffer = buffer_handle.read(app_cx);
1548                    if buffer.is_dirty() {
1549                        if let Some(file) = buffer.file() {
1550                            let path = file.path().to_string_lossy().to_string();
1551                            unsaved_buffers.push(path);
1552                        }
1553                    }
1554                }
1555            })
1556            .ok();
1557
1558            Arc::new(ProjectSnapshot {
1559                worktree_snapshots,
1560                unsaved_buffer_paths: unsaved_buffers,
1561                timestamp: Utc::now(),
1562            })
1563        })
1564    }
1565
1566    fn worktree_snapshot(
1567        worktree: Entity<project::Worktree>,
1568        git_store: Entity<GitStore>,
1569        cx: &App,
1570    ) -> Task<WorktreeSnapshot> {
1571        cx.spawn(async move |cx| {
1572            // Get worktree path and snapshot
1573            let worktree_info = cx.update(|app_cx| {
1574                let worktree = worktree.read(app_cx);
1575                let path = worktree.abs_path().to_string_lossy().to_string();
1576                let snapshot = worktree.snapshot();
1577                (path, snapshot)
1578            });
1579
1580            let Ok((worktree_path, _snapshot)) = worktree_info else {
1581                return WorktreeSnapshot {
1582                    worktree_path: String::new(),
1583                    git_state: None,
1584                };
1585            };
1586
1587            let git_state = git_store
1588                .update(cx, |git_store, cx| {
1589                    git_store
1590                        .repositories()
1591                        .values()
1592                        .find(|repo| {
1593                            repo.read(cx)
1594                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1595                                .is_some()
1596                        })
1597                        .cloned()
1598                })
1599                .ok()
1600                .flatten()
1601                .map(|repo| {
1602                    repo.read_with(cx, |repo, _| {
1603                        let current_branch =
1604                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1605                        repo.send_job(|state, _| async move {
1606                            let RepositoryState::Local { backend, .. } = state else {
1607                                return GitState {
1608                                    remote_url: None,
1609                                    head_sha: None,
1610                                    current_branch,
1611                                    diff: None,
1612                                };
1613                            };
1614
1615                            let remote_url = backend.remote_url("origin");
1616                            let head_sha = backend.head_sha();
1617                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1618
1619                            GitState {
1620                                remote_url,
1621                                head_sha,
1622                                current_branch,
1623                                diff,
1624                            }
1625                        })
1626                    })
1627                });
1628
1629            let git_state = match git_state {
1630                Some(git_state) => match git_state.ok() {
1631                    Some(git_state) => git_state.await.ok(),
1632                    None => None,
1633                },
1634                None => None,
1635            };
1636
1637            WorktreeSnapshot {
1638                worktree_path,
1639                git_state,
1640            }
1641        })
1642    }
1643
1644    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1645        let mut markdown = Vec::new();
1646
1647        if let Some(summary) = self.summary() {
1648            writeln!(markdown, "# {summary}\n")?;
1649        };
1650
1651        for message in self.messages() {
1652            writeln!(
1653                markdown,
1654                "## {role}\n",
1655                role = match message.role {
1656                    Role::User => "User",
1657                    Role::Assistant => "Assistant",
1658                    Role::System => "System",
1659                }
1660            )?;
1661            for segment in &message.segments {
1662                match segment {
1663                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1664                    MessageSegment::Thinking(text) => {
1665                        writeln!(markdown, "<think>{}</think>\n", text)?
1666                    }
1667                }
1668            }
1669
1670            for tool_use in self.tool_uses_for_message(message.id, cx) {
1671                writeln!(
1672                    markdown,
1673                    "**Use Tool: {} ({})**",
1674                    tool_use.name, tool_use.id
1675                )?;
1676                writeln!(markdown, "```json")?;
1677                writeln!(
1678                    markdown,
1679                    "{}",
1680                    serde_json::to_string_pretty(&tool_use.input)?
1681                )?;
1682                writeln!(markdown, "```")?;
1683            }
1684
1685            for tool_result in self.tool_results_for_message(message.id) {
1686                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1687                if tool_result.is_error {
1688                    write!(markdown, " (Error)")?;
1689                }
1690
1691                writeln!(markdown, "**\n")?;
1692                writeln!(markdown, "{}", tool_result.content)?;
1693            }
1694        }
1695
1696        Ok(String::from_utf8_lossy(&markdown).to_string())
1697    }
1698
1699    pub fn keep_edits_in_range(
1700        &mut self,
1701        buffer: Entity<language::Buffer>,
1702        buffer_range: Range<language::Anchor>,
1703        cx: &mut Context<Self>,
1704    ) {
1705        self.action_log.update(cx, |action_log, cx| {
1706            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1707        });
1708    }
1709
1710    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1711        self.action_log
1712            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1713    }
1714
1715    pub fn action_log(&self) -> &Entity<ActionLog> {
1716        &self.action_log
1717    }
1718
1719    pub fn project(&self) -> &Entity<Project> {
1720        &self.project
1721    }
1722
1723    pub fn cumulative_token_usage(&self) -> TokenUsage {
1724        self.cumulative_token_usage.clone()
1725    }
1726
1727    pub fn is_getting_too_long(&self, cx: &App) -> bool {
1728        let model_registry = LanguageModelRegistry::read_global(cx);
1729        let Some(model) = model_registry.active_model() else {
1730            return false;
1731        };
1732
1733        let max_tokens = model.max_token_count();
1734
1735        let current_usage =
1736            self.cumulative_token_usage.input_tokens + self.cumulative_token_usage.output_tokens;
1737
1738        #[cfg(debug_assertions)]
1739        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1740            .unwrap_or("0.9".to_string())
1741            .parse()
1742            .unwrap();
1743        #[cfg(not(debug_assertions))]
1744        let warning_threshold: f32 = 0.9;
1745
1746        current_usage as f32 >= (max_tokens as f32 * warning_threshold)
1747    }
1748
1749    pub fn deny_tool_use(
1750        &mut self,
1751        tool_use_id: LanguageModelToolUseId,
1752        tool_name: Arc<str>,
1753        cx: &mut Context<Self>,
1754    ) {
1755        let err = Err(anyhow::anyhow!(
1756            "Permission to run tool action denied by user"
1757        ));
1758
1759        self.tool_use
1760            .insert_tool_output(tool_use_id.clone(), tool_name, err);
1761
1762        cx.emit(ThreadEvent::ToolFinished {
1763            tool_use_id,
1764            pending_tool_use: None,
1765            canceled: true,
1766        });
1767    }
1768}
1769
1770#[derive(Debug, Clone)]
1771pub enum ThreadError {
1772    PaymentRequired,
1773    MaxMonthlySpendReached,
1774    Message {
1775        header: SharedString,
1776        message: SharedString,
1777    },
1778}
1779
1780#[derive(Debug, Clone)]
1781pub enum ThreadEvent {
1782    ShowError(ThreadError),
1783    StreamedCompletion,
1784    StreamedAssistantText(MessageId, String),
1785    StreamedAssistantThinking(MessageId, String),
1786    DoneStreaming,
1787    MessageAdded(MessageId),
1788    MessageEdited(MessageId),
1789    MessageDeleted(MessageId),
1790    SummaryChanged,
1791    UsePendingTools,
1792    ToolFinished {
1793        #[allow(unused)]
1794        tool_use_id: LanguageModelToolUseId,
1795        /// The pending tool use that corresponds to this tool.
1796        pending_tool_use: Option<PendingToolUse>,
1797        /// Whether the tool was canceled by the user.
1798        canceled: bool,
1799    },
1800    CheckpointChanged,
1801    ToolConfirmationNeeded,
1802}
1803
1804impl EventEmitter<ThreadEvent> for Thread {}
1805
1806struct PendingCompletion {
1807    id: usize,
1808    _task: Task<()>,
1809}