thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::ops::Range;
   4use std::sync::Arc;
   5
   6use anyhow::{Context as _, Result};
   7use assistant_settings::AssistantSettings;
   8use assistant_tool::{ActionLog, Tool, ToolWorkingSet};
   9use chrono::{DateTime, Utc};
  10use collections::{BTreeMap, HashMap, HashSet};
  11use fs::Fs;
  12use futures::future::Shared;
  13use futures::{FutureExt, StreamExt as _};
  14use git::repository::DiffType;
  15use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task, WeakEntity};
  16use language_model::{
  17    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
  18    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  19    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
  20    Role, StopReason, TokenUsage,
  21};
  22use project::git_store::{GitStore, GitStoreCheckpoint, RepositoryState};
  23use project::{Project, Worktree};
  24use prompt_store::{
  25    AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt,
  26};
  27use schemars::JsonSchema;
  28use serde::{Deserialize, Serialize};
  29use settings::Settings;
  30use util::{ResultExt as _, TryFutureExt as _, maybe, post_inc};
  31use uuid::Uuid;
  32
  33use crate::context::{AssistantContext, ContextId, attach_context_to_message};
  34use crate::thread_store::{
  35    SerializedMessage, SerializedMessageSegment, SerializedThread, SerializedToolResult,
  36    SerializedToolUse,
  37};
  38use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
  39
/// Selects what a completion request built from a [`Thread`] is for.
///
/// `Thread::to_completion_request` branches on this value: `Chat` requests have
/// tool results attached to each message, while `Summarize` requests do not need
/// tool use.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A regular chat completion (tool results are attached to messages).
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
  46
  47#[derive(
  48    Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize, JsonSchema,
  49)]
  50pub struct ThreadId(Arc<str>);
  51
  52impl ThreadId {
  53    pub fn new() -> Self {
  54        Self(Uuid::new_v4().to_string().into())
  55    }
  56}
  57
  58impl std::fmt::Display for ThreadId {
  59    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  60        write!(f, "{}", self.0)
  61    }
  62}
  63
  64impl From<&str> for ThreadId {
  65    fn from(value: &str) -> Self {
  66        Self(value.into())
  67    }
  68}
  69
/// Identifier of a [`Message`] within a single [`Thread`].
///
/// The owning thread hands ids out from its `next_message_id` counter (see
/// `Thread::insert_message`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
pub struct MessageId(pub(crate) usize);

impl MessageId {
    /// Returns the current id and advances `self` to the next one.
    ///
    /// NOTE(review): post-increment-by-one semantics are assumed from the name of
    /// `util::post_inc`; `Thread::deserialize` relies on ids being `last + 1`.
    fn post_inc(&mut self) -> Self {
        Self(post_inc(&mut self.0))
    }
}
  78
  79/// A message in a [`Thread`].
  80#[derive(Debug, Clone)]
  81pub struct Message {
  82    pub id: MessageId,
  83    pub role: Role,
  84    pub segments: Vec<MessageSegment>,
  85}
  86
  87impl Message {
  88    pub fn push_thinking(&mut self, text: &str) {
  89        if let Some(MessageSegment::Thinking(segment)) = self.segments.last_mut() {
  90            segment.push_str(text);
  91        } else {
  92            self.segments
  93                .push(MessageSegment::Thinking(text.to_string()));
  94        }
  95    }
  96
  97    pub fn push_text(&mut self, text: &str) {
  98        if let Some(MessageSegment::Text(segment)) = self.segments.last_mut() {
  99            segment.push_str(text);
 100        } else {
 101            self.segments.push(MessageSegment::Text(text.to_string()));
 102        }
 103    }
 104
 105    pub fn to_string(&self) -> String {
 106        let mut result = String::new();
 107        for segment in &self.segments {
 108            match segment {
 109                MessageSegment::Text(text) => result.push_str(text),
 110                MessageSegment::Thinking(text) => {
 111                    result.push_str("<think>");
 112                    result.push_str(text);
 113                    result.push_str("</think>");
 114                }
 115            }
 116        }
 117        result
 118    }
 119}
 120
 121#[derive(Debug, Clone)]
 122pub enum MessageSegment {
 123    Text(String),
 124    Thinking(String),
 125}
 126
 127impl MessageSegment {
 128    pub fn text_mut(&mut self) -> &mut String {
 129        match self {
 130            Self::Text(text) => text,
 131            Self::Thinking(text) => text,
 132        }
 133    }
 134}
 135
/// Snapshot of the project's worktrees associated with a thread.
///
/// `Thread::new` captures one through `Thread::project_snapshot` as the thread's
/// initial project snapshot, and `Thread::serialize` persists it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// Per-worktree details. NOTE(review): presumably one entry per worktree — confirm in `project_snapshot`.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// NOTE(review): presumably paths of buffers with unsaved edits at snapshot time.
    pub unsaved_buffer_paths: Vec<String>,
    /// NOTE(review): presumably the time the snapshot was taken.
    pub timestamp: DateTime<Utc>,
}

/// Details recorded for a single worktree in a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// The worktree's path, stored as a string.
    pub worktree_path: String,
    /// Git details, if any were captured. NOTE(review): presumably `None` outside a git repository.
    pub git_state: Option<GitState>,
}

/// Git repository state recorded for a worktree; every field is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    pub remote_url: Option<String>,
    pub head_sha: Option<String>,
    pub current_branch: Option<String>,
    /// NOTE(review): presumably the working-tree diff text — confirm against the snapshot code.
    pub diff: Option<String>,
}
 156
/// A git checkpoint tied to the message it belongs to.
///
/// `Thread::restore_checkpoint` restores the git state and, on success, truncates
/// the thread, removing `message_id` and every later message.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    message_id: MessageId,
    git_checkpoint: GitStoreCheckpoint,
}

/// User feedback on a thread, held in `Thread::feedback`.
/// NOTE(review): where it is set is not visible in this part of the file.
#[derive(Copy, Clone, Debug)]
pub enum ThreadFeedback {
    Positive,
    Negative,
}
 168
 169pub enum LastRestoreCheckpoint {
 170    Pending {
 171        message_id: MessageId,
 172    },
 173    Error {
 174        message_id: MessageId,
 175        error: String,
 176    },
 177}
 178
 179impl LastRestoreCheckpoint {
 180    pub fn message_id(&self) -> MessageId {
 181        match self {
 182            LastRestoreCheckpoint::Pending { message_id } => *message_id,
 183            LastRestoreCheckpoint::Error { message_id, .. } => *message_id,
 184        }
 185    }
 186}
 187
/// Progress of generating the thread's detailed summary. It is persisted with the
/// serialized thread.
///
/// In this file, only the `Generated` text is read back, via
/// `Thread::latest_detailed_summary_or_text`.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub enum DetailedSummaryState {
    /// No detailed summary exists (the default).
    #[default]
    NotGenerated,
    /// A summary is being generated.
    /// NOTE(review): `message_id` presumably marks the last message it covers — confirm.
    Generating {
        message_id: MessageId,
    },
    /// A finished summary.
    /// NOTE(review): `message_id` presumably marks the last message it covers — confirm.
    Generated {
        text: SharedString,
        message_id: MessageId,
    },
}
 200
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Identifier of this thread.
    id: ThreadId,
    /// Refreshed whenever a message is inserted, edited, or deleted (`touch_updated_at`).
    updated_at: DateTime<Utc>,
    /// Short title; `summary_or_default` falls back to "New Thread" when unset.
    summary: Option<SharedString>,
    /// NOTE(review): presumably the in-flight task producing `summary`; not read in the visible code.
    pending_summary: Task<Option<()>>,
    /// State of the detailed summary; persisted by `serialize`.
    detailed_summary_state: DetailedSummaryState,
    /// Messages in insertion order.
    messages: Vec<Message>,
    /// Id assigned to the next inserted message.
    next_message_id: MessageId,
    /// Every context item attached to any message, keyed by its id.
    context: BTreeMap<ContextId, AssistantContext>,
    /// Ids of the context items attached to each message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Inputs for the system prompt; `to_completion_request` logs an error when unset.
    system_prompt_context: Option<AssistantSystemPromptContext>,
    /// Restorable git checkpoints, keyed by the message they belong to.
    checkpoints_by_message: HashMap<MessageId, ThreadCheckpoint>,
    /// NOTE(review): not read in the visible code; presumably counts completions started.
    completion_count: usize,
    /// Completions in flight; `is_generating` is true while this is non-empty.
    pending_completions: Vec<PendingCompletion>,
    project: Entity<Project>,
    prompt_builder: Arc<PromptBuilder>,
    /// Tool set; its enabled tools are sent with requests when the model supports tools.
    tools: Arc<ToolWorkingSet>,
    /// Bookkeeping for tool uses and their results.
    tool_use: ToolUseState,
    action_log: Entity<ActionLog>,
    /// Status of the latest `restore_checkpoint`: `Pending` while running, `Error` on
    /// failure, cleared on success.
    last_restore_checkpoint: Option<LastRestoreCheckpoint>,
    /// Checkpoint for the latest user message. `finalize_pending_checkpoint` moves it
    /// into `checkpoints_by_message` once generation has finished.
    pending_checkpoint: Option<ThreadCheckpoint>,
    /// Project snapshot captured in `new`, or restored from storage in `deserialize`.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Cumulative token usage; restored by `deserialize` and persisted by `serialize`.
    cumulative_token_usage: TokenUsage,
    feedback: Option<ThreadFeedback>,
}
 227
 228impl Thread {
 229    pub fn new(
 230        project: Entity<Project>,
 231        tools: Arc<ToolWorkingSet>,
 232        prompt_builder: Arc<PromptBuilder>,
 233        cx: &mut Context<Self>,
 234    ) -> Self {
 235        Self {
 236            id: ThreadId::new(),
 237            updated_at: Utc::now(),
 238            summary: None,
 239            pending_summary: Task::ready(None),
 240            detailed_summary_state: DetailedSummaryState::NotGenerated,
 241            messages: Vec::new(),
 242            next_message_id: MessageId(0),
 243            context: BTreeMap::default(),
 244            context_by_message: HashMap::default(),
 245            system_prompt_context: None,
 246            checkpoints_by_message: HashMap::default(),
 247            completion_count: 0,
 248            pending_completions: Vec::new(),
 249            project: project.clone(),
 250            prompt_builder,
 251            tools: tools.clone(),
 252            last_restore_checkpoint: None,
 253            pending_checkpoint: None,
 254            tool_use: ToolUseState::new(tools.clone()),
 255            action_log: cx.new(|_| ActionLog::new()),
 256            initial_project_snapshot: {
 257                let project_snapshot = Self::project_snapshot(project, cx);
 258                cx.foreground_executor()
 259                    .spawn(async move { Some(project_snapshot.await) })
 260                    .shared()
 261            },
 262            cumulative_token_usage: TokenUsage::default(),
 263            feedback: None,
 264        }
 265    }
 266
 267    pub fn deserialize(
 268        id: ThreadId,
 269        serialized: SerializedThread,
 270        project: Entity<Project>,
 271        tools: Arc<ToolWorkingSet>,
 272        prompt_builder: Arc<PromptBuilder>,
 273        cx: &mut Context<Self>,
 274    ) -> Self {
 275        let next_message_id = MessageId(
 276            serialized
 277                .messages
 278                .last()
 279                .map(|message| message.id.0 + 1)
 280                .unwrap_or(0),
 281        );
 282        let tool_use =
 283            ToolUseState::from_serialized_messages(tools.clone(), &serialized.messages, |_| true);
 284
 285        Self {
 286            id,
 287            updated_at: serialized.updated_at,
 288            summary: Some(serialized.summary),
 289            pending_summary: Task::ready(None),
 290            detailed_summary_state: serialized.detailed_summary_state,
 291            messages: serialized
 292                .messages
 293                .into_iter()
 294                .map(|message| Message {
 295                    id: message.id,
 296                    role: message.role,
 297                    segments: message
 298                        .segments
 299                        .into_iter()
 300                        .map(|segment| match segment {
 301                            SerializedMessageSegment::Text { text } => MessageSegment::Text(text),
 302                            SerializedMessageSegment::Thinking { text } => {
 303                                MessageSegment::Thinking(text)
 304                            }
 305                        })
 306                        .collect(),
 307                })
 308                .collect(),
 309            next_message_id,
 310            context: BTreeMap::default(),
 311            context_by_message: HashMap::default(),
 312            system_prompt_context: None,
 313            checkpoints_by_message: HashMap::default(),
 314            completion_count: 0,
 315            pending_completions: Vec::new(),
 316            last_restore_checkpoint: None,
 317            pending_checkpoint: None,
 318            project,
 319            prompt_builder,
 320            tools,
 321            tool_use,
 322            action_log: cx.new(|_| ActionLog::new()),
 323            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 324            cumulative_token_usage: serialized.cumulative_token_usage,
 325            feedback: None,
 326        }
 327    }
 328
 329    pub fn id(&self) -> &ThreadId {
 330        &self.id
 331    }
 332
 333    pub fn is_empty(&self) -> bool {
 334        self.messages.is_empty()
 335    }
 336
 337    pub fn updated_at(&self) -> DateTime<Utc> {
 338        self.updated_at
 339    }
 340
 341    pub fn touch_updated_at(&mut self) {
 342        self.updated_at = Utc::now();
 343    }
 344
 345    pub fn summary(&self) -> Option<SharedString> {
 346        self.summary.clone()
 347    }
 348
 349    pub fn summary_or_default(&self) -> SharedString {
 350        const DEFAULT: SharedString = SharedString::new_static("New Thread");
 351        self.summary.clone().unwrap_or(DEFAULT)
 352    }
 353
 354    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 355        self.summary = Some(summary.into());
 356        cx.emit(ThreadEvent::SummaryChanged);
 357    }
 358
 359    pub fn latest_detailed_summary_or_text(&self) -> SharedString {
 360        self.latest_detailed_summary()
 361            .unwrap_or_else(|| self.text().into())
 362    }
 363
 364    fn latest_detailed_summary(&self) -> Option<SharedString> {
 365        if let DetailedSummaryState::Generated { text, .. } = &self.detailed_summary_state {
 366            Some(text.clone())
 367        } else {
 368            None
 369        }
 370    }
 371
 372    pub fn message(&self, id: MessageId) -> Option<&Message> {
 373        self.messages.iter().find(|message| message.id == id)
 374    }
 375
 376    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 377        self.messages.iter()
 378    }
 379
 380    pub fn is_generating(&self) -> bool {
 381        !self.pending_completions.is_empty() || !self.all_tools_finished()
 382    }
 383
 384    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 385        &self.tools
 386    }
 387
 388    pub fn pending_tool(&self, id: &LanguageModelToolUseId) -> Option<&PendingToolUse> {
 389        self.tool_use
 390            .pending_tool_uses()
 391            .into_iter()
 392            .find(|tool_use| &tool_use.id == id)
 393    }
 394
 395    pub fn tools_needing_confirmation(&self) -> impl Iterator<Item = &PendingToolUse> {
 396        self.tool_use
 397            .pending_tool_uses()
 398            .into_iter()
 399            .filter(|tool_use| tool_use.status.needs_confirmation())
 400    }
 401
 402    pub fn has_pending_tool_uses(&self) -> bool {
 403        !self.tool_use.pending_tool_uses().is_empty()
 404    }
 405
 406    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 407        self.checkpoints_by_message.get(&id).cloned()
 408    }
 409
 410    pub fn restore_checkpoint(
 411        &mut self,
 412        checkpoint: ThreadCheckpoint,
 413        cx: &mut Context<Self>,
 414    ) -> Task<Result<()>> {
 415        self.last_restore_checkpoint = Some(LastRestoreCheckpoint::Pending {
 416            message_id: checkpoint.message_id,
 417        });
 418        cx.emit(ThreadEvent::CheckpointChanged);
 419        cx.notify();
 420
 421        let project = self.project.read(cx);
 422        let restore = project
 423            .git_store()
 424            .read(cx)
 425            .restore_checkpoint(checkpoint.git_checkpoint.clone(), cx);
 426        cx.spawn(async move |this, cx| {
 427            let result = restore.await;
 428            this.update(cx, |this, cx| {
 429                if let Err(err) = result.as_ref() {
 430                    this.last_restore_checkpoint = Some(LastRestoreCheckpoint::Error {
 431                        message_id: checkpoint.message_id,
 432                        error: err.to_string(),
 433                    });
 434                } else {
 435                    this.truncate(checkpoint.message_id, cx);
 436                    this.last_restore_checkpoint = None;
 437                }
 438                this.pending_checkpoint = None;
 439                cx.emit(ThreadEvent::CheckpointChanged);
 440                cx.notify();
 441            })?;
 442            result
 443        })
 444    }
 445
 446    fn finalize_pending_checkpoint(&mut self, cx: &mut Context<Self>) {
 447        let pending_checkpoint = if self.is_generating() {
 448            return;
 449        } else if let Some(checkpoint) = self.pending_checkpoint.take() {
 450            checkpoint
 451        } else {
 452            return;
 453        };
 454
 455        let git_store = self.project.read(cx).git_store().clone();
 456        let final_checkpoint = git_store.read(cx).checkpoint(cx);
 457        cx.spawn(async move |this, cx| match final_checkpoint.await {
 458            Ok(final_checkpoint) => {
 459                let equal = git_store
 460                    .read_with(cx, |store, cx| {
 461                        store.compare_checkpoints(
 462                            pending_checkpoint.git_checkpoint.clone(),
 463                            final_checkpoint.clone(),
 464                            cx,
 465                        )
 466                    })?
 467                    .await
 468                    .unwrap_or(false);
 469
 470                if equal {
 471                    git_store
 472                        .read_with(cx, |store, cx| {
 473                            store.delete_checkpoint(pending_checkpoint.git_checkpoint, cx)
 474                        })?
 475                        .detach();
 476                } else {
 477                    this.update(cx, |this, cx| {
 478                        this.insert_checkpoint(pending_checkpoint, cx)
 479                    })?;
 480                }
 481
 482                git_store
 483                    .read_with(cx, |store, cx| {
 484                        store.delete_checkpoint(final_checkpoint, cx)
 485                    })?
 486                    .detach();
 487
 488                Ok(())
 489            }
 490            Err(_) => this.update(cx, |this, cx| {
 491                this.insert_checkpoint(pending_checkpoint, cx)
 492            }),
 493        })
 494        .detach();
 495    }
 496
 497    fn insert_checkpoint(&mut self, checkpoint: ThreadCheckpoint, cx: &mut Context<Self>) {
 498        self.checkpoints_by_message
 499            .insert(checkpoint.message_id, checkpoint);
 500        cx.emit(ThreadEvent::CheckpointChanged);
 501        cx.notify();
 502    }
 503
 504    pub fn last_restore_checkpoint(&self) -> Option<&LastRestoreCheckpoint> {
 505        self.last_restore_checkpoint.as_ref()
 506    }
 507
 508    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 509        let Some(message_ix) = self
 510            .messages
 511            .iter()
 512            .rposition(|message| message.id == message_id)
 513        else {
 514            return;
 515        };
 516        for deleted_message in self.messages.drain(message_ix..) {
 517            self.context_by_message.remove(&deleted_message.id);
 518            self.checkpoints_by_message.remove(&deleted_message.id);
 519        }
 520        cx.notify();
 521    }
 522
 523    pub fn context_for_message(&self, id: MessageId) -> impl Iterator<Item = &AssistantContext> {
 524        self.context_by_message
 525            .get(&id)
 526            .into_iter()
 527            .flat_map(|context| {
 528                context
 529                    .iter()
 530                    .filter_map(|context_id| self.context.get(&context_id))
 531            })
 532    }
 533
 534    /// Returns whether all of the tool uses have finished running.
 535    pub fn all_tools_finished(&self) -> bool {
 536        // If the only pending tool uses left are the ones with errors, then
 537        // that means that we've finished running all of the pending tools.
 538        self.tool_use
 539            .pending_tool_uses()
 540            .iter()
 541            .all(|tool_use| tool_use.status.is_error())
 542    }
 543
 544    pub fn tool_uses_for_message(&self, id: MessageId, cx: &App) -> Vec<ToolUse> {
 545        self.tool_use.tool_uses_for_message(id, cx)
 546    }
 547
 548    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 549        self.tool_use.tool_results_for_message(id)
 550    }
 551
 552    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 553        self.tool_use.tool_result(id)
 554    }
 555
 556    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 557        self.tool_use.message_has_tool_results(message_id)
 558    }
 559
 560    pub fn insert_user_message(
 561        &mut self,
 562        text: impl Into<String>,
 563        context: Vec<AssistantContext>,
 564        git_checkpoint: Option<GitStoreCheckpoint>,
 565        cx: &mut Context<Self>,
 566    ) -> MessageId {
 567        let message_id =
 568            self.insert_message(Role::User, vec![MessageSegment::Text(text.into())], cx);
 569        let context_ids = context
 570            .iter()
 571            .map(|context| context.id())
 572            .collect::<Vec<_>>();
 573        self.context
 574            .extend(context.into_iter().map(|context| (context.id(), context)));
 575        self.context_by_message.insert(message_id, context_ids);
 576        if let Some(git_checkpoint) = git_checkpoint {
 577            self.pending_checkpoint = Some(ThreadCheckpoint {
 578                message_id,
 579                git_checkpoint,
 580            });
 581        }
 582        message_id
 583    }
 584
 585    pub fn insert_message(
 586        &mut self,
 587        role: Role,
 588        segments: Vec<MessageSegment>,
 589        cx: &mut Context<Self>,
 590    ) -> MessageId {
 591        let id = self.next_message_id.post_inc();
 592        self.messages.push(Message { id, role, segments });
 593        self.touch_updated_at();
 594        cx.emit(ThreadEvent::MessageAdded(id));
 595        id
 596    }
 597
 598    pub fn edit_message(
 599        &mut self,
 600        id: MessageId,
 601        new_role: Role,
 602        new_segments: Vec<MessageSegment>,
 603        cx: &mut Context<Self>,
 604    ) -> bool {
 605        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 606            return false;
 607        };
 608        message.role = new_role;
 609        message.segments = new_segments;
 610        self.touch_updated_at();
 611        cx.emit(ThreadEvent::MessageEdited(id));
 612        true
 613    }
 614
 615    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 616        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 617            return false;
 618        };
 619        self.messages.remove(index);
 620        self.context_by_message.remove(&id);
 621        self.touch_updated_at();
 622        cx.emit(ThreadEvent::MessageDeleted(id));
 623        true
 624    }
 625
 626    /// Returns the representation of this [`Thread`] in a textual form.
 627    ///
 628    /// This is the representation we use when attaching a thread as context to another thread.
 629    pub fn text(&self) -> String {
 630        let mut text = String::new();
 631
 632        for message in &self.messages {
 633            text.push_str(match message.role {
 634                language_model::Role::User => "User:",
 635                language_model::Role::Assistant => "Assistant:",
 636                language_model::Role::System => "System:",
 637            });
 638            text.push('\n');
 639
 640            for segment in &message.segments {
 641                match segment {
 642                    MessageSegment::Text(content) => text.push_str(content),
 643                    MessageSegment::Thinking(content) => {
 644                        text.push_str(&format!("<think>{}</think>", content))
 645                    }
 646                }
 647            }
 648            text.push('\n');
 649        }
 650
 651        text
 652    }
 653
 654    /// Serializes this thread into a format for storage or telemetry.
 655    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 656        let initial_project_snapshot = self.initial_project_snapshot.clone();
 657        cx.spawn(async move |this, cx| {
 658            let initial_project_snapshot = initial_project_snapshot.await;
 659            this.read_with(cx, |this, cx| SerializedThread {
 660                version: SerializedThread::VERSION.to_string(),
 661                summary: this.summary_or_default(),
 662                updated_at: this.updated_at(),
 663                messages: this
 664                    .messages()
 665                    .map(|message| SerializedMessage {
 666                        id: message.id,
 667                        role: message.role,
 668                        segments: message
 669                            .segments
 670                            .iter()
 671                            .map(|segment| match segment {
 672                                MessageSegment::Text(text) => {
 673                                    SerializedMessageSegment::Text { text: text.clone() }
 674                                }
 675                                MessageSegment::Thinking(text) => {
 676                                    SerializedMessageSegment::Thinking { text: text.clone() }
 677                                }
 678                            })
 679                            .collect(),
 680                        tool_uses: this
 681                            .tool_uses_for_message(message.id, cx)
 682                            .into_iter()
 683                            .map(|tool_use| SerializedToolUse {
 684                                id: tool_use.id,
 685                                name: tool_use.name,
 686                                input: tool_use.input,
 687                            })
 688                            .collect(),
 689                        tool_results: this
 690                            .tool_results_for_message(message.id)
 691                            .into_iter()
 692                            .map(|tool_result| SerializedToolResult {
 693                                tool_use_id: tool_result.tool_use_id.clone(),
 694                                is_error: tool_result.is_error,
 695                                content: tool_result.content.clone(),
 696                            })
 697                            .collect(),
 698                    })
 699                    .collect(),
 700                initial_project_snapshot,
 701                cumulative_token_usage: this.cumulative_token_usage.clone(),
 702                detailed_summary_state: this.detailed_summary_state.clone(),
 703            })
 704        })
 705    }
 706
 707    pub fn set_system_prompt_context(&mut self, context: AssistantSystemPromptContext) {
 708        self.system_prompt_context = Some(context);
 709    }
 710
 711    pub fn system_prompt_context(&self) -> &Option<AssistantSystemPromptContext> {
 712        &self.system_prompt_context
 713    }
 714
 715    pub fn load_system_prompt_context(
 716        &self,
 717        cx: &App,
 718    ) -> Task<(AssistantSystemPromptContext, Option<ThreadError>)> {
 719        let project = self.project.read(cx);
 720        let tasks = project
 721            .visible_worktrees(cx)
 722            .map(|worktree| {
 723                Self::load_worktree_info_for_system_prompt(
 724                    project.fs().clone(),
 725                    worktree.read(cx),
 726                    cx,
 727                )
 728            })
 729            .collect::<Vec<_>>();
 730
 731        cx.spawn(async |_cx| {
 732            let results = futures::future::join_all(tasks).await;
 733            let mut first_err = None;
 734            let worktrees = results
 735                .into_iter()
 736                .map(|(worktree, err)| {
 737                    if first_err.is_none() && err.is_some() {
 738                        first_err = err;
 739                    }
 740                    worktree
 741                })
 742                .collect::<Vec<_>>();
 743            (AssistantSystemPromptContext::new(worktrees), first_err)
 744        })
 745    }
 746
 747    fn load_worktree_info_for_system_prompt(
 748        fs: Arc<dyn Fs>,
 749        worktree: &Worktree,
 750        cx: &App,
 751    ) -> Task<(WorktreeInfoForSystemPrompt, Option<ThreadError>)> {
 752        let root_name = worktree.root_name().into();
 753        let abs_path = worktree.abs_path();
 754
 755        // Note that Cline supports `.clinerules` being a directory, but that is not currently
 756        // supported. This doesn't seem to occur often in GitHub repositories.
 757        const RULES_FILE_NAMES: [&'static str; 6] = [
 758            ".rules",
 759            ".cursorrules",
 760            ".windsurfrules",
 761            ".clinerules",
 762            ".github/copilot-instructions.md",
 763            "CLAUDE.md",
 764        ];
 765        let selected_rules_file = RULES_FILE_NAMES
 766            .into_iter()
 767            .filter_map(|name| {
 768                worktree
 769                    .entry_for_path(name)
 770                    .filter(|entry| entry.is_file())
 771                    .map(|entry| (entry.path.clone(), worktree.absolutize(&entry.path)))
 772            })
 773            .next();
 774
 775        if let Some((rel_rules_path, abs_rules_path)) = selected_rules_file {
 776            cx.spawn(async move |_| {
 777                let rules_file_result = maybe!(async move {
 778                    let abs_rules_path = abs_rules_path?;
 779                    let text = fs.load(&abs_rules_path).await.with_context(|| {
 780                        format!("Failed to load assistant rules file {:?}", abs_rules_path)
 781                    })?;
 782                    anyhow::Ok(RulesFile {
 783                        rel_path: rel_rules_path,
 784                        abs_path: abs_rules_path.into(),
 785                        text: text.trim().to_string(),
 786                    })
 787                })
 788                .await;
 789                let (rules_file, rules_file_error) = match rules_file_result {
 790                    Ok(rules_file) => (Some(rules_file), None),
 791                    Err(err) => (
 792                        None,
 793                        Some(ThreadError::Message {
 794                            header: "Error loading rules file".into(),
 795                            message: format!("{err}").into(),
 796                        }),
 797                    ),
 798                };
 799                let worktree_info = WorktreeInfoForSystemPrompt {
 800                    root_name,
 801                    abs_path,
 802                    rules_file,
 803                };
 804                (worktree_info, rules_file_error)
 805            })
 806        } else {
 807            Task::ready((
 808                WorktreeInfoForSystemPrompt {
 809                    root_name,
 810                    abs_path,
 811                    rules_file: None,
 812                },
 813                None,
 814            ))
 815        }
 816    }
 817
 818    pub fn send_to_model(
 819        &mut self,
 820        model: Arc<dyn LanguageModel>,
 821        request_kind: RequestKind,
 822        cx: &mut Context<Self>,
 823    ) {
 824        let mut request = self.to_completion_request(request_kind, cx);
 825        if model.supports_tools() {
 826            request.tools = {
 827                let mut tools = Vec::new();
 828                tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 829                    LanguageModelRequestTool {
 830                        name: tool.name(),
 831                        description: tool.description(),
 832                        input_schema: tool.input_schema(model.tool_input_format()),
 833                    }
 834                }));
 835
 836                tools
 837            };
 838        }
 839
 840        self.stream_completion(request, model, cx);
 841    }
 842
 843    pub fn used_tools_since_last_user_message(&self) -> bool {
 844        for message in self.messages.iter().rev() {
 845            if self.tool_use.message_has_tool_results(message.id) {
 846                return true;
 847            } else if message.role == Role::User {
 848                return false;
 849            }
 850        }
 851
 852        false
 853    }
 854
 855    pub fn to_completion_request(
 856        &self,
 857        request_kind: RequestKind,
 858        cx: &App,
 859    ) -> LanguageModelRequest {
 860        let mut request = LanguageModelRequest {
 861            messages: vec![],
 862            tools: Vec::new(),
 863            stop: Vec::new(),
 864            temperature: None,
 865        };
 866
 867        if let Some(system_prompt_context) = self.system_prompt_context.as_ref() {
 868            if let Some(system_prompt) = self
 869                .prompt_builder
 870                .generate_assistant_system_prompt(system_prompt_context)
 871                .context("failed to generate assistant system prompt")
 872                .log_err()
 873            {
 874                request.messages.push(LanguageModelRequestMessage {
 875                    role: Role::System,
 876                    content: vec![MessageContent::Text(system_prompt)],
 877                    cache: true,
 878                });
 879            }
 880        } else {
 881            log::error!("system_prompt_context not set.")
 882        }
 883
 884        let mut referenced_context_ids = HashSet::default();
 885
 886        for message in &self.messages {
 887            if let Some(context_ids) = self.context_by_message.get(&message.id) {
 888                referenced_context_ids.extend(context_ids);
 889            }
 890
 891            let mut request_message = LanguageModelRequestMessage {
 892                role: message.role,
 893                content: Vec::new(),
 894                cache: false,
 895            };
 896
 897            match request_kind {
 898                RequestKind::Chat => {
 899                    self.tool_use
 900                        .attach_tool_results(message.id, &mut request_message);
 901                }
 902                RequestKind::Summarize => {
 903                    // We don't care about tool use during summarization.
 904                    if self.tool_use.message_has_tool_results(message.id) {
 905                        continue;
 906                    }
 907                }
 908            }
 909
 910            if !message.segments.is_empty() {
 911                request_message
 912                    .content
 913                    .push(MessageContent::Text(message.to_string()));
 914            }
 915
 916            match request_kind {
 917                RequestKind::Chat => {
 918                    self.tool_use
 919                        .attach_tool_uses(message.id, &mut request_message);
 920                }
 921                RequestKind::Summarize => {
 922                    // We don't care about tool use during summarization.
 923                }
 924            };
 925
 926            request.messages.push(request_message);
 927        }
 928
 929        // Set a cache breakpoint at the second-to-last message.
 930        // https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
 931        let breakpoint_index = request.messages.len() - 2;
 932        for (index, message) in request.messages.iter_mut().enumerate() {
 933            message.cache = index == breakpoint_index;
 934        }
 935
 936        if !referenced_context_ids.is_empty() {
 937            let mut context_message = LanguageModelRequestMessage {
 938                role: Role::User,
 939                content: Vec::new(),
 940                cache: false,
 941            };
 942
 943            let referenced_context = referenced_context_ids
 944                .into_iter()
 945                .filter_map(|context_id| self.context.get(context_id));
 946            attach_context_to_message(&mut context_message, referenced_context, cx);
 947
 948            request.messages.push(context_message);
 949        }
 950
 951        self.attached_tracked_files_state(&mut request.messages, cx);
 952
 953        request
 954    }
 955
 956    fn attached_tracked_files_state(
 957        &self,
 958        messages: &mut Vec<LanguageModelRequestMessage>,
 959        cx: &App,
 960    ) {
 961        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 962
 963        let mut stale_message = String::new();
 964
 965        let action_log = self.action_log.read(cx);
 966
 967        for stale_file in action_log.stale_buffers(cx) {
 968            let Some(file) = stale_file.read(cx).file() else {
 969                continue;
 970            };
 971
 972            if stale_message.is_empty() {
 973                write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
 974            }
 975
 976            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 977        }
 978
 979        let mut content = Vec::with_capacity(2);
 980
 981        if !stale_message.is_empty() {
 982            content.push(stale_message.into());
 983        }
 984
 985        if action_log.has_edited_files_since_project_diagnostics_check() {
 986            content.push(
 987                "\n\nWhen you're done making changes, make sure to check project diagnostics \
 988                and fix all errors AND warnings you introduced! \
 989                DO NOT mention you're going to do this until you're done."
 990                    .into(),
 991            );
 992        }
 993
 994        if !content.is_empty() {
 995            let context_message = LanguageModelRequestMessage {
 996                role: Role::User,
 997                content,
 998                cache: false,
 999            };
1000
1001            messages.push(context_message);
1002        }
1003    }
1004
    /// Streams a completion for `request` from `model` into this thread.
    ///
    /// A task is spawned and tracked in `pending_completions` under an id taken from
    /// `completion_count`. For each streamed event it:
    /// - `StartMessage`: inserts an empty assistant message;
    /// - `Text` / `Thinking`: appends the chunk to the last message if that is an assistant
    ///   message (emitting `StreamedAssistantText` / `StreamedAssistantThinking`), otherwise
    ///   inserts a new assistant message containing the chunk;
    /// - `ToolUse`: attaches the request to the most recent assistant message, inserting a
    ///   "Using tool..." placeholder when its first segment is empty or it has no segments
    ///   (or creating such a message if none exists), then registers the tool use;
    /// - `UsageUpdate`: replaces this completion's previous contribution to
    ///   `cumulative_token_usage` with the newly reported usage;
    /// - `Stop`: records the stop reason.
    /// After every event it updates `updated_at`, emits `StreamedCompletion` and notifies.
    ///
    /// When the stream ends without error, the completion is removed from
    /// `pending_completions`. If the thread has no summary yet and at least two messages,
    /// a summary is requested. Afterwards (if the thread still exists), the following happens
    /// regardless of outcome:
    /// - the pending checkpoint is finalized;
    /// - a `ToolUse` stop emits `UsePendingTools`;
    /// - an error emits `ShowError` and calls `cancel_last_completion`;
    /// - `DoneStreaming` is emitted;
    /// - the token usage of this completion is sent to telemetry (when the initial usage
    ///   could be read).
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            // Snapshot of the usage before this completion, used for per-completion telemetry.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                let mut stop_reason = StopReason::EndTurn;
                // Usage most recently reported for this completion.
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(
                                    Role::Assistant,
                                    vec![MessageSegment::Text(String::new())],
                                    cx,
                                );
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Swap the previously reported usage for the new report rather
                                // than adding both.
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                // NOTE(review): if the thread has no messages at all, the chunk
                                // is dropped here — confirm that cannot happen.
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_text(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::Thinking(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.push_thinking(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantThinking(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantThinking` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Thinking(chunk.to_string())],
                                            cx,
                                        );
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                let last_assistant_message = thread
                                    .messages
                                    .iter_mut()
                                    .rfind(|message| message.role == Role::Assistant);

                                // Make sure the message owning the tool use has visible text.
                                let last_assistant_message_id =
                                    if let Some(message) = last_assistant_message {
                                        if let Some(segment) = message.segments.first_mut() {
                                            let text = segment.text_mut();
                                            if text.is_empty() {
                                                text.push_str("Using tool...");
                                            }
                                        } else {
                                            message.segments.push(MessageSegment::Text(
                                                "Using tool...".to_string(),
                                            ));
                                        }

                                        message.id
                                    } else {
                                        thread.insert_message(
                                            Role::Assistant,
                                            vec![MessageSegment::Text("Using tool...".to_string())],
                                            cx,
                                        )
                                    };
                                thread.tool_use.request_tool_use(
                                    last_assistant_message_id,
                                    tool_use,
                                    cx,
                                );
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();
                    })?;

                    // Yield between events so other tasks get a chance to run.
                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    thread.finalize_pending_checkpoint(cx);
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                cx.emit(ThreadEvent::UsePendingTools);
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else {
                                // Flatten the whole error chain into one message.
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message {
                                    header: "Error interacting with language model".into(),
                                    message: SharedString::from(error_message.clone()),
                                }));
                            }

                            // NOTE(review): this pops the most recently pushed pending
                            // completion, which is not necessarily this one if several are in
                            // flight — confirm that only one can be pending at a time.
                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::DoneStreaming);

                    if let Ok(initial_usage) = initial_token_usage {
                        // Usage attributable to this completion only.
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
1203
1204    pub fn summarize(&mut self, cx: &mut Context<Self>) {
1205        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
1206            return;
1207        };
1208        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
1209            return;
1210        };
1211
1212        if !provider.is_authenticated(cx) {
1213            return;
1214        }
1215
1216        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1217        request.messages.push(LanguageModelRequestMessage {
1218            role: Role::User,
1219            content: vec![
1220                "Generate a concise 3-7 word title for this conversation, omitting punctuation. \
1221                 Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`. \
1222                 If the conversation is about a specific subject, include it in the title. \
1223                 Be descriptive. DO NOT speak in the first person."
1224                    .into(),
1225            ],
1226            cache: false,
1227        });
1228
1229        self.pending_summary = cx.spawn(async move |this, cx| {
1230            async move {
1231                let stream = model.stream_completion_text(request, &cx);
1232                let mut messages = stream.await?;
1233
1234                let mut new_summary = String::new();
1235                while let Some(message) = messages.stream.next().await {
1236                    let text = message?;
1237                    let mut lines = text.lines();
1238                    new_summary.extend(lines.next());
1239
1240                    // Stop if the LLM generated multiple lines.
1241                    if lines.next().is_some() {
1242                        break;
1243                    }
1244                }
1245
1246                this.update(cx, |this, cx| {
1247                    if !new_summary.is_empty() {
1248                        this.summary = Some(new_summary.into());
1249                    }
1250
1251                    cx.emit(ThreadEvent::SummaryChanged);
1252                })?;
1253
1254                anyhow::Ok(())
1255            }
1256            .log_err()
1257            .await
1258        });
1259    }
1260
1261    pub fn generate_detailed_summary(&mut self, cx: &mut Context<Self>) -> Option<Task<()>> {
1262        let last_message_id = self.messages.last().map(|message| message.id)?;
1263
1264        match &self.detailed_summary_state {
1265            DetailedSummaryState::Generating { message_id, .. }
1266            | DetailedSummaryState::Generated { message_id, .. }
1267                if *message_id == last_message_id =>
1268            {
1269                // Already up-to-date
1270                return None;
1271            }
1272            _ => {}
1273        }
1274
1275        let provider = LanguageModelRegistry::read_global(cx).active_provider()?;
1276        let model = LanguageModelRegistry::read_global(cx).active_model()?;
1277
1278        if !provider.is_authenticated(cx) {
1279            return None;
1280        }
1281
1282        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
1283
1284        request.messages.push(LanguageModelRequestMessage {
1285            role: Role::User,
1286            content: vec![
1287                "Generate a detailed summary of this conversation. Include:\n\
1288                1. A brief overview of what was discussed\n\
1289                2. Key facts or information discovered\n\
1290                3. Outcomes or conclusions reached\n\
1291                4. Any action items or next steps if any\n\
1292                Format it in Markdown with headings and bullet points."
1293                    .into(),
1294            ],
1295            cache: false,
1296        });
1297
1298        let task = cx.spawn(async move |thread, cx| {
1299            let stream = model.stream_completion_text(request, &cx);
1300            let Some(mut messages) = stream.await.log_err() else {
1301                thread
1302                    .update(cx, |this, _cx| {
1303                        this.detailed_summary_state = DetailedSummaryState::NotGenerated;
1304                    })
1305                    .log_err();
1306
1307                return;
1308            };
1309
1310            let mut new_detailed_summary = String::new();
1311
1312            while let Some(chunk) = messages.stream.next().await {
1313                if let Some(chunk) = chunk.log_err() {
1314                    new_detailed_summary.push_str(&chunk);
1315                }
1316            }
1317
1318            thread
1319                .update(cx, |this, _cx| {
1320                    this.detailed_summary_state = DetailedSummaryState::Generated {
1321                        text: new_detailed_summary.into(),
1322                        message_id: last_message_id,
1323                    };
1324                })
1325                .log_err();
1326        });
1327
1328        self.detailed_summary_state = DetailedSummaryState::Generating {
1329            message_id: last_message_id,
1330        };
1331
1332        Some(task)
1333    }
1334
1335    pub fn is_generating_detailed_summary(&self) -> bool {
1336        matches!(
1337            self.detailed_summary_state,
1338            DetailedSummaryState::Generating { .. }
1339        )
1340    }
1341
1342    pub fn use_pending_tools(
1343        &mut self,
1344        cx: &mut Context<Self>,
1345    ) -> impl IntoIterator<Item = PendingToolUse> + use<> {
1346        let request = self.to_completion_request(RequestKind::Chat, cx);
1347        let messages = Arc::new(request.messages);
1348        let pending_tool_uses = self
1349            .tool_use
1350            .pending_tool_uses()
1351            .into_iter()
1352            .filter(|tool_use| tool_use.status.is_idle())
1353            .cloned()
1354            .collect::<Vec<_>>();
1355
1356        for tool_use in pending_tool_uses.iter() {
1357            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1358                if tool.needs_confirmation()
1359                    && !AssistantSettings::get_global(cx).always_allow_tool_actions
1360                {
1361                    self.tool_use.confirm_tool_use(
1362                        tool_use.id.clone(),
1363                        tool_use.ui_text.clone(),
1364                        tool_use.input.clone(),
1365                        messages.clone(),
1366                        tool,
1367                    );
1368                    cx.emit(ThreadEvent::ToolConfirmationNeeded);
1369                } else {
1370                    self.run_tool(
1371                        tool_use.id.clone(),
1372                        tool_use.ui_text.clone(),
1373                        tool_use.input.clone(),
1374                        &messages,
1375                        tool,
1376                        cx,
1377                    );
1378                }
1379            } else if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
1380                self.run_tool(
1381                    tool_use.id.clone(),
1382                    tool_use.ui_text.clone(),
1383                    tool_use.input.clone(),
1384                    &messages,
1385                    tool,
1386                    cx,
1387                );
1388            }
1389        }
1390
1391        pending_tool_uses
1392    }
1393
1394    pub fn run_tool(
1395        &mut self,
1396        tool_use_id: LanguageModelToolUseId,
1397        ui_text: impl Into<SharedString>,
1398        input: serde_json::Value,
1399        messages: &[LanguageModelRequestMessage],
1400        tool: Arc<dyn Tool>,
1401        cx: &mut Context<Thread>,
1402    ) {
1403        let task = self.spawn_tool_use(tool_use_id.clone(), messages, input, tool, cx);
1404        self.tool_use
1405            .run_pending_tool(tool_use_id, ui_text.into(), task);
1406    }
1407
1408    fn spawn_tool_use(
1409        &mut self,
1410        tool_use_id: LanguageModelToolUseId,
1411        messages: &[LanguageModelRequestMessage],
1412        input: serde_json::Value,
1413        tool: Arc<dyn Tool>,
1414        cx: &mut Context<Thread>,
1415    ) -> Task<()> {
1416        let tool_name: Arc<str> = tool.name().into();
1417        let run_tool = tool.run(
1418            input,
1419            messages,
1420            self.project.clone(),
1421            self.action_log.clone(),
1422            cx,
1423        );
1424
1425        cx.spawn({
1426            async move |thread: WeakEntity<Thread>, cx| {
1427                let output = run_tool.await;
1428
1429                thread
1430                    .update(cx, |thread, cx| {
1431                        let pending_tool_use = thread.tool_use.insert_tool_output(
1432                            tool_use_id.clone(),
1433                            tool_name,
1434                            output,
1435                        );
1436
1437                        cx.emit(ThreadEvent::ToolFinished {
1438                            tool_use_id,
1439                            pending_tool_use,
1440                            canceled: false,
1441                        });
1442                    })
1443                    .ok();
1444            }
1445        })
1446    }
1447
1448    pub fn attach_tool_results(
1449        &mut self,
1450        updated_context: Vec<AssistantContext>,
1451        cx: &mut Context<Self>,
1452    ) {
1453        self.context.extend(
1454            updated_context
1455                .into_iter()
1456                .map(|context| (context.id(), context)),
1457        );
1458
1459        // Insert a user message to contain the tool results.
1460        self.insert_user_message(
1461            // TODO: Sending up a user message without any content results in the model sending back
1462            // responses that also don't have any content. We currently don't handle this case well,
1463            // so for now we provide some text to keep the model on track.
1464            "Here are the tool results.",
1465            Vec::new(),
1466            None,
1467            cx,
1468        );
1469    }
1470
1471    /// Cancels the last pending completion, if there are any pending.
1472    ///
1473    /// Returns whether a completion was canceled.
1474    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
1475        let canceled = if self.pending_completions.pop().is_some() {
1476            true
1477        } else {
1478            let mut canceled = false;
1479            for pending_tool_use in self.tool_use.cancel_pending() {
1480                canceled = true;
1481                cx.emit(ThreadEvent::ToolFinished {
1482                    tool_use_id: pending_tool_use.id.clone(),
1483                    pending_tool_use: Some(pending_tool_use),
1484                    canceled: true,
1485                });
1486            }
1487            canceled
1488        };
1489        self.finalize_pending_checkpoint(cx);
1490        canceled
1491    }
1492
    /// Returns the feedback given to the thread, if any.
    ///
    /// `report_feedback` sets this value before it sends the feedback to telemetry.
    pub fn feedback(&self) -> Option<ThreadFeedback> {
        self.feedback
    }
1497
1498    /// Reports feedback about the thread and stores it in our telemetry backend.
1499    pub fn report_feedback(
1500        &mut self,
1501        feedback: ThreadFeedback,
1502        cx: &mut Context<Self>,
1503    ) -> Task<Result<()>> {
1504        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1505        let serialized_thread = self.serialize(cx);
1506        let thread_id = self.id().clone();
1507        let client = self.project.read(cx).client();
1508        self.feedback = Some(feedback);
1509        cx.notify();
1510
1511        cx.background_spawn(async move {
1512            let final_project_snapshot = final_project_snapshot.await;
1513            let serialized_thread = serialized_thread.await?;
1514            let thread_data =
1515                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1516
1517            let rating = match feedback {
1518                ThreadFeedback::Positive => "positive",
1519                ThreadFeedback::Negative => "negative",
1520            };
1521            telemetry::event!(
1522                "Assistant Thread Rated",
1523                rating,
1524                thread_id,
1525                thread_data,
1526                final_project_snapshot
1527            );
1528            client.telemetry().flush_events();
1529
1530            Ok(())
1531        })
1532    }
1533
1534    /// Create a snapshot of the current project state including git information and unsaved buffers.
1535    fn project_snapshot(
1536        project: Entity<Project>,
1537        cx: &mut Context<Self>,
1538    ) -> Task<Arc<ProjectSnapshot>> {
1539        let git_store = project.read(cx).git_store().clone();
1540        let worktree_snapshots: Vec<_> = project
1541            .read(cx)
1542            .visible_worktrees(cx)
1543            .map(|worktree| Self::worktree_snapshot(worktree, git_store.clone(), cx))
1544            .collect();
1545
1546        cx.spawn(async move |_, cx| {
1547            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1548
1549            let mut unsaved_buffers = Vec::new();
1550            cx.update(|app_cx| {
1551                let buffer_store = project.read(app_cx).buffer_store();
1552                for buffer_handle in buffer_store.read(app_cx).buffers() {
1553                    let buffer = buffer_handle.read(app_cx);
1554                    if buffer.is_dirty() {
1555                        if let Some(file) = buffer.file() {
1556                            let path = file.path().to_string_lossy().to_string();
1557                            unsaved_buffers.push(path);
1558                        }
1559                    }
1560                }
1561            })
1562            .ok();
1563
1564            Arc::new(ProjectSnapshot {
1565                worktree_snapshots,
1566                unsaved_buffer_paths: unsaved_buffers,
1567                timestamp: Utc::now(),
1568            })
1569        })
1570    }
1571
1572    fn worktree_snapshot(
1573        worktree: Entity<project::Worktree>,
1574        git_store: Entity<GitStore>,
1575        cx: &App,
1576    ) -> Task<WorktreeSnapshot> {
1577        cx.spawn(async move |cx| {
1578            // Get worktree path and snapshot
1579            let worktree_info = cx.update(|app_cx| {
1580                let worktree = worktree.read(app_cx);
1581                let path = worktree.abs_path().to_string_lossy().to_string();
1582                let snapshot = worktree.snapshot();
1583                (path, snapshot)
1584            });
1585
1586            let Ok((worktree_path, _snapshot)) = worktree_info else {
1587                return WorktreeSnapshot {
1588                    worktree_path: String::new(),
1589                    git_state: None,
1590                };
1591            };
1592
1593            let git_state = git_store
1594                .update(cx, |git_store, cx| {
1595                    git_store
1596                        .repositories()
1597                        .values()
1598                        .find(|repo| {
1599                            repo.read(cx)
1600                                .abs_path_to_repo_path(&worktree.read(cx).abs_path())
1601                                .is_some()
1602                        })
1603                        .cloned()
1604                })
1605                .ok()
1606                .flatten()
1607                .map(|repo| {
1608                    repo.read_with(cx, |repo, _| {
1609                        let current_branch =
1610                            repo.branch.as_ref().map(|branch| branch.name.to_string());
1611                        repo.send_job(|state, _| async move {
1612                            let RepositoryState::Local { backend, .. } = state else {
1613                                return GitState {
1614                                    remote_url: None,
1615                                    head_sha: None,
1616                                    current_branch,
1617                                    diff: None,
1618                                };
1619                            };
1620
1621                            let remote_url = backend.remote_url("origin");
1622                            let head_sha = backend.head_sha();
1623                            let diff = backend.diff(DiffType::HeadToWorktree).await.ok();
1624
1625                            GitState {
1626                                remote_url,
1627                                head_sha,
1628                                current_branch,
1629                                diff,
1630                            }
1631                        })
1632                    })
1633                });
1634
1635            let git_state = match git_state {
1636                Some(git_state) => match git_state.ok() {
1637                    Some(git_state) => git_state.await.ok(),
1638                    None => None,
1639                },
1640                None => None,
1641            };
1642
1643            WorktreeSnapshot {
1644                worktree_path,
1645                git_state,
1646            }
1647        })
1648    }
1649
1650    pub fn to_markdown(&self, cx: &App) -> Result<String> {
1651        let mut markdown = Vec::new();
1652
1653        if let Some(summary) = self.summary() {
1654            writeln!(markdown, "# {summary}\n")?;
1655        };
1656
1657        for message in self.messages() {
1658            writeln!(
1659                markdown,
1660                "## {role}\n",
1661                role = match message.role {
1662                    Role::User => "User",
1663                    Role::Assistant => "Assistant",
1664                    Role::System => "System",
1665                }
1666            )?;
1667            for segment in &message.segments {
1668                match segment {
1669                    MessageSegment::Text(text) => writeln!(markdown, "{}\n", text)?,
1670                    MessageSegment::Thinking(text) => {
1671                        writeln!(markdown, "<think>{}</think>\n", text)?
1672                    }
1673                }
1674            }
1675
1676            for tool_use in self.tool_uses_for_message(message.id, cx) {
1677                writeln!(
1678                    markdown,
1679                    "**Use Tool: {} ({})**",
1680                    tool_use.name, tool_use.id
1681                )?;
1682                writeln!(markdown, "```json")?;
1683                writeln!(
1684                    markdown,
1685                    "{}",
1686                    serde_json::to_string_pretty(&tool_use.input)?
1687                )?;
1688                writeln!(markdown, "```")?;
1689            }
1690
1691            for tool_result in self.tool_results_for_message(message.id) {
1692                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1693                if tool_result.is_error {
1694                    write!(markdown, " (Error)")?;
1695                }
1696
1697                writeln!(markdown, "**\n")?;
1698                writeln!(markdown, "{}", tool_result.content)?;
1699            }
1700        }
1701
1702        Ok(String::from_utf8_lossy(&markdown).to_string())
1703    }
1704
1705    pub fn keep_edits_in_range(
1706        &mut self,
1707        buffer: Entity<language::Buffer>,
1708        buffer_range: Range<language::Anchor>,
1709        cx: &mut Context<Self>,
1710    ) {
1711        self.action_log.update(cx, |action_log, cx| {
1712            action_log.keep_edits_in_range(buffer, buffer_range, cx)
1713        });
1714    }
1715
1716    pub fn keep_all_edits(&mut self, cx: &mut Context<Self>) {
1717        self.action_log
1718            .update(cx, |action_log, cx| action_log.keep_all_edits(cx));
1719    }
1720
1721    pub fn action_log(&self) -> &Entity<ActionLog> {
1722        &self.action_log
1723    }
1724
1725    pub fn project(&self) -> &Entity<Project> {
1726        &self.project
1727    }
1728
1729    pub fn cumulative_token_usage(&self) -> TokenUsage {
1730        self.cumulative_token_usage.clone()
1731    }
1732
1733    pub fn is_getting_too_long(&self, cx: &App) -> bool {
1734        let model_registry = LanguageModelRegistry::read_global(cx);
1735        let Some(model) = model_registry.active_model() else {
1736            return false;
1737        };
1738
1739        let max_tokens = model.max_token_count();
1740
1741        let current_usage =
1742            self.cumulative_token_usage.input_tokens + self.cumulative_token_usage.output_tokens;
1743
1744        #[cfg(debug_assertions)]
1745        let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD")
1746            .unwrap_or("0.9".to_string())
1747            .parse()
1748            .unwrap();
1749        #[cfg(not(debug_assertions))]
1750        let warning_threshold: f32 = 0.9;
1751
1752        current_usage as f32 >= (max_tokens as f32 * warning_threshold)
1753    }
1754
1755    pub fn deny_tool_use(
1756        &mut self,
1757        tool_use_id: LanguageModelToolUseId,
1758        tool_name: Arc<str>,
1759        cx: &mut Context<Self>,
1760    ) {
1761        let err = Err(anyhow::anyhow!(
1762            "Permission to run tool action denied by user"
1763        ));
1764
1765        self.tool_use
1766            .insert_tool_output(tool_use_id.clone(), tool_name, err);
1767
1768        cx.emit(ThreadEvent::ToolFinished {
1769            tool_use_id,
1770            pending_tool_use: None,
1771            canceled: true,
1772        });
1773    }
1774}
1775
1776#[derive(Debug, Clone)]
1777pub enum ThreadError {
1778    PaymentRequired,
1779    MaxMonthlySpendReached,
1780    Message {
1781        header: SharedString,
1782        message: SharedString,
1783    },
1784}
1785
1786#[derive(Debug, Clone)]
1787pub enum ThreadEvent {
1788    ShowError(ThreadError),
1789    StreamedCompletion,
1790    StreamedAssistantText(MessageId, String),
1791    StreamedAssistantThinking(MessageId, String),
1792    DoneStreaming,
1793    MessageAdded(MessageId),
1794    MessageEdited(MessageId),
1795    MessageDeleted(MessageId),
1796    SummaryChanged,
1797    UsePendingTools,
1798    ToolFinished {
1799        #[allow(unused)]
1800        tool_use_id: LanguageModelToolUseId,
1801        /// The pending tool use that corresponds to this tool.
1802        pending_tool_use: Option<PendingToolUse>,
1803        /// Whether the tool was canceled by the user.
1804        canceled: bool,
1805    },
1806    CheckpointChanged,
1807    ToolConfirmationNeeded,
1808}
1809
1810impl EventEmitter<ThreadEvent> for Thread {}
1811
1812struct PendingCompletion {
1813    id: usize,
1814    _task: Task<()>,
1815}