thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::sync::Arc;
   4
   5use anyhow::{Context as _, Result};
   6use assistant_tool::{ActionLog, ToolWorkingSet};
   7use chrono::{DateTime, Utc};
   8use collections::{BTreeMap, HashMap, HashSet};
   9use futures::future::Shared;
  10use futures::{FutureExt, StreamExt as _};
  11use git;
  12use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
  13use language_model::{
  14    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
  15    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  16    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
  17    Role, StopReason, TokenUsage,
  18};
  19use project::git::GitStoreCheckpoint;
  20use project::Project;
  21use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
  22use scripting_tool::{ScriptingSession, ScriptingTool};
  23use serde::{Deserialize, Serialize};
  24use util::{post_inc, ResultExt, TryFutureExt as _};
  25use uuid::Uuid;
  26
  27use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
  28use crate::thread_store::{
  29    SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
  30};
  31use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
  32
/// The kind of completion request built from a [`Thread`].
///
/// `Thread::to_completion_request` attaches tool uses and tool results to the
/// request messages only for [`RequestKind::Chat`].
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A regular conversation request; tool uses and tool results are included.
    Chat,
    /// Used when summarizing a thread.
    Summarize,
}
  39
  40#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  41pub struct ThreadId(Arc<str>);
  42
  43impl ThreadId {
  44    pub fn new() -> Self {
  45        Self(Uuid::new_v4().to_string().into())
  46    }
  47}
  48
  49impl std::fmt::Display for ThreadId {
  50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  51        write!(f, "{}", self.0)
  52    }
  53}
  54
  55#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  56pub struct MessageId(pub(crate) usize);
  57
  58impl MessageId {
  59    fn post_inc(&mut self) -> Self {
  60        Self(post_inc(&mut self.0))
  61    }
  62}
  63
/// A message in a [`Thread`].
#[derive(Debug, Clone)]
pub struct Message {
    /// Identifier assigned by the owning thread when the message was inserted.
    pub id: MessageId,
    /// Who authored the message (user, assistant, or system).
    pub role: Role,
    /// The message body; assistant text is appended here chunk by chunk while streaming.
    pub text: String,
}
  71
/// A serializable record of project state.
///
/// A thread keeps one of these as its `initial_project_snapshot`; it is
/// produced by `Thread::project_snapshot`, which is defined outside this view.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProjectSnapshot {
    /// Snapshots of individual worktrees.
    // NOTE(review): presumably one entry per project worktree — confirm in `project_snapshot`.
    pub worktree_snapshots: Vec<WorktreeSnapshot>,
    /// Paths of buffers with unsaved changes.
    // NOTE(review): meaning inferred from the field name — confirm where it is populated.
    pub unsaved_buffer_paths: Vec<String>,
    /// UTC time associated with the snapshot.
    pub timestamp: DateTime<Utc>,
}
  78
/// Serializable state of a single worktree inside a [`ProjectSnapshot`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorktreeSnapshot {
    /// Path of the worktree, stored as a string.
    pub worktree_path: String,
    /// Git information for the worktree, when present.
    // NOTE(review): presumably `None` when the worktree is not a git repository — confirm.
    pub git_state: Option<GitState>,
}
  84
/// Serializable git information recorded in a [`WorktreeSnapshot`].
///
/// Every field is optional. NOTE(review): the field descriptions below are
/// inferred from the names; the code that populates them is outside this view.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitState {
    /// URL of the repository's remote.
    pub remote_url: Option<String>,
    /// SHA of the commit at `HEAD`.
    pub head_sha: Option<String>,
    /// Name of the currently checked-out branch.
    pub current_branch: Option<String>,
    /// Textual diff of the repository.
    pub diff: Option<String>,
}
  92
/// Pairs a git checkpoint with the message it was recorded for.
///
/// Obtained from `Thread::checkpoint_for_message` and consumed by
/// `Thread::restore_checkpoint`, which restores the git checkpoint and then
/// truncates the thread at `message_id`.
#[derive(Clone)]
pub struct ThreadCheckpoint {
    /// The message the checkpoint belongs to.
    message_id: MessageId,
    /// The git-store checkpoint to restore.
    git_checkpoint: GitStoreCheckpoint,
}
  98
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Identifier of this thread.
    id: ThreadId,
    /// Time of the last modification; refreshed by `touch_updated_at`.
    updated_at: DateTime<Utc>,
    /// Title of the thread, if one has been set or generated.
    summary: Option<SharedString>,
    /// Task generating the summary; replaced each time `summarize` starts a request.
    pending_summary: Task<Option<()>>,
    /// The messages of the conversation, in insertion order.
    messages: Vec<Message>,
    /// Id that will be assigned to the next inserted message.
    next_message_id: MessageId,
    /// Every context snapshot attached to the thread, keyed by context id.
    context: BTreeMap<ContextId, ContextSnapshot>,
    /// Ids of the context snapshots attached to each user message.
    context_by_message: HashMap<MessageId, Vec<ContextId>>,
    /// Git checkpoint recorded for a message, when one was supplied on insertion.
    checkpoints_by_message: HashMap<MessageId, GitStoreCheckpoint>,
    /// Counter used to issue ids for pending completions.
    completion_count: usize,
    /// Completions that are currently streaming.
    pending_completions: Vec<PendingCompletion>,
    /// The project this thread operates on.
    project: Entity<Project>,
    /// Builds the system prompt placed at the start of every request.
    prompt_builder: Arc<PromptBuilder>,
    /// The tools available to this thread.
    tools: Arc<ToolWorkingSet>,
    /// Tool-use state for every tool except the scripting tool.
    tool_use: ToolUseState,
    /// Log consulted for buffers that became stale since they were last read.
    action_log: Entity<ActionLog>,
    /// Session created alongside the thread for the scripting tool.
    scripting_session: Entity<ScriptingSession>,
    /// Tool-use state for the scripting tool only.
    scripting_tool_use: ToolUseState,
    /// Shared task resolving to the project snapshot taken when the thread was created.
    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
    /// Token usage accumulated over all completions streamed by this thread instance.
    cumulative_token_usage: TokenUsage,
}
 122
 123impl Thread {
 124    pub fn new(
 125        project: Entity<Project>,
 126        tools: Arc<ToolWorkingSet>,
 127        prompt_builder: Arc<PromptBuilder>,
 128        cx: &mut Context<Self>,
 129    ) -> Self {
 130        Self {
 131            id: ThreadId::new(),
 132            updated_at: Utc::now(),
 133            summary: None,
 134            pending_summary: Task::ready(None),
 135            messages: Vec::new(),
 136            next_message_id: MessageId(0),
 137            context: BTreeMap::default(),
 138            context_by_message: HashMap::default(),
 139            checkpoints_by_message: HashMap::default(),
 140            completion_count: 0,
 141            pending_completions: Vec::new(),
 142            project: project.clone(),
 143            prompt_builder,
 144            tools,
 145            tool_use: ToolUseState::new(),
 146            scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
 147            scripting_tool_use: ToolUseState::new(),
 148            action_log: cx.new(|_| ActionLog::new()),
 149            initial_project_snapshot: {
 150                let project_snapshot = Self::project_snapshot(project, cx);
 151                cx.foreground_executor()
 152                    .spawn(async move { Some(project_snapshot.await) })
 153                    .shared()
 154            },
 155            cumulative_token_usage: TokenUsage::default(),
 156        }
 157    }
 158
 159    pub fn deserialize(
 160        id: ThreadId,
 161        serialized: SerializedThread,
 162        project: Entity<Project>,
 163        tools: Arc<ToolWorkingSet>,
 164        prompt_builder: Arc<PromptBuilder>,
 165        cx: &mut Context<Self>,
 166    ) -> Self {
 167        let next_message_id = MessageId(
 168            serialized
 169                .messages
 170                .last()
 171                .map(|message| message.id.0 + 1)
 172                .unwrap_or(0),
 173        );
 174        let tool_use = ToolUseState::from_serialized_messages(&serialized.messages, |name| {
 175            name != ScriptingTool::NAME
 176        });
 177        let scripting_tool_use =
 178            ToolUseState::from_serialized_messages(&serialized.messages, |name| {
 179                name == ScriptingTool::NAME
 180            });
 181        let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));
 182
 183        Self {
 184            id,
 185            updated_at: serialized.updated_at,
 186            summary: Some(serialized.summary),
 187            pending_summary: Task::ready(None),
 188            messages: serialized
 189                .messages
 190                .into_iter()
 191                .map(|message| Message {
 192                    id: message.id,
 193                    role: message.role,
 194                    text: message.text,
 195                })
 196                .collect(),
 197            next_message_id,
 198            context: BTreeMap::default(),
 199            context_by_message: HashMap::default(),
 200            checkpoints_by_message: HashMap::default(),
 201            completion_count: 0,
 202            pending_completions: Vec::new(),
 203            project,
 204            prompt_builder,
 205            tools,
 206            tool_use,
 207            action_log: cx.new(|_| ActionLog::new()),
 208            scripting_session,
 209            scripting_tool_use,
 210            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 211            // TODO: persist token usage?
 212            cumulative_token_usage: TokenUsage::default(),
 213        }
 214    }
 215
    /// Returns the identifier of this thread.
    pub fn id(&self) -> &ThreadId {
        &self.id
    }
 219
    /// Returns `true` when the thread contains no messages.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
 223
    /// Returns the time at which the thread was last marked as updated.
    pub fn updated_at(&self) -> DateTime<Utc> {
        self.updated_at
    }
 227
    /// Sets the thread's `updated_at` timestamp to the current UTC time.
    pub fn touch_updated_at(&mut self) {
        self.updated_at = Utc::now();
    }
 231
    /// Returns a clone of the thread's summary, or `None` if it has none yet.
    pub fn summary(&self) -> Option<SharedString> {
        self.summary.clone()
    }
 235
 236    pub fn summary_or_default(&self) -> SharedString {
 237        const DEFAULT: SharedString = SharedString::new_static("New Thread");
 238        self.summary.clone().unwrap_or(DEFAULT)
 239    }
 240
 241    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 242        self.summary = Some(summary.into());
 243        cx.emit(ThreadEvent::SummaryChanged);
 244    }
 245
 246    pub fn message(&self, id: MessageId) -> Option<&Message> {
 247        self.messages.iter().find(|message| message.id == id)
 248    }
 249
    /// Iterates over the thread's messages in the order they are stored.
    pub fn messages(&self) -> impl Iterator<Item = &Message> {
        self.messages.iter()
    }
 253
 254    pub fn is_generating(&self) -> bool {
 255        !self.pending_completions.is_empty() || !self.all_tools_finished()
 256    }
 257
    /// Returns the working set of tools available to this thread.
    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
        &self.tools
    }
 261
 262    pub fn checkpoint_for_message(&self, id: MessageId) -> Option<ThreadCheckpoint> {
 263        let checkpoint = self.checkpoints_by_message.get(&id).cloned()?;
 264        Some(ThreadCheckpoint {
 265            message_id: id,
 266            git_checkpoint: checkpoint,
 267        })
 268    }
 269
 270    pub fn restore_checkpoint(
 271        &mut self,
 272        checkpoint: ThreadCheckpoint,
 273        cx: &mut Context<Self>,
 274    ) -> Task<Result<()>> {
 275        let project = self.project.read(cx);
 276        let restore = project
 277            .git_store()
 278            .read(cx)
 279            .restore_checkpoint(checkpoint.git_checkpoint, cx);
 280        cx.spawn(async move |this, cx| {
 281            restore.await?;
 282            this.update(cx, |this, cx| this.truncate(checkpoint.message_id, cx))
 283        })
 284    }
 285
 286    pub fn truncate(&mut self, message_id: MessageId, cx: &mut Context<Self>) {
 287        let Some(message_ix) = self
 288            .messages
 289            .iter()
 290            .rposition(|message| message.id == message_id)
 291        else {
 292            return;
 293        };
 294        for deleted_message in self.messages.drain(message_ix..) {
 295            self.context_by_message.remove(&deleted_message.id);
 296            self.checkpoints_by_message.remove(&deleted_message.id);
 297        }
 298        cx.notify();
 299    }
 300
 301    pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
 302        let context = self.context_by_message.get(&id)?;
 303        Some(
 304            context
 305                .into_iter()
 306                .filter_map(|context_id| self.context.get(&context_id))
 307                .cloned()
 308                .collect::<Vec<_>>(),
 309        )
 310    }
 311
 312    /// Returns whether all of the tool uses have finished running.
 313    pub fn all_tools_finished(&self) -> bool {
 314        let mut all_pending_tool_uses = self
 315            .tool_use
 316            .pending_tool_uses()
 317            .into_iter()
 318            .chain(self.scripting_tool_use.pending_tool_uses());
 319
 320        // If the only pending tool uses left are the ones with errors, then
 321        // that means that we've finished running all of the pending tools.
 322        all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
 323    }
 324
    /// Returns the tool uses recorded for message `id` in the non-scripting
    /// tool state.
    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
        self.tool_use.tool_uses_for_message(id)
    }
 328
    /// Returns the tool uses recorded for message `id` in the scripting tool
    /// state.
    pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
        self.scripting_tool_use.tool_uses_for_message(id)
    }
 332
    /// Returns the tool results recorded for message `id` in the non-scripting
    /// tool state.
    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
        self.tool_use.tool_results_for_message(id)
    }
 336
    /// Looks up the result of the tool use `id` in the non-scripting tool
    /// state; the scripting tool state is not consulted.
    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
        self.tool_use.tool_result(id)
    }
 340
    /// Returns the tool results recorded for message `id` in the scripting
    /// tool state.
    pub fn scripting_tool_results_for_message(
        &self,
        id: MessageId,
    ) -> Vec<&LanguageModelToolResult> {
        self.scripting_tool_use.tool_results_for_message(id)
    }
 347
    /// Reports whether the non-scripting tool state holds tool results for
    /// message `message_id`.
    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
        self.tool_use.message_has_tool_results(message_id)
    }
 351
    /// Reports whether the scripting tool state holds tool results for
    /// message `message_id`.
    pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
        self.scripting_tool_use.message_has_tool_results(message_id)
    }
 355
 356    pub fn insert_user_message(
 357        &mut self,
 358        text: impl Into<String>,
 359        context: Vec<ContextSnapshot>,
 360        checkpoint: Option<GitStoreCheckpoint>,
 361        cx: &mut Context<Self>,
 362    ) -> MessageId {
 363        let message_id = self.insert_message(Role::User, text, cx);
 364        let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
 365        self.context
 366            .extend(context.into_iter().map(|context| (context.id, context)));
 367        self.context_by_message.insert(message_id, context_ids);
 368        if let Some(checkpoint) = checkpoint {
 369            self.checkpoints_by_message.insert(message_id, checkpoint);
 370        }
 371        message_id
 372    }
 373
 374    pub fn insert_message(
 375        &mut self,
 376        role: Role,
 377        text: impl Into<String>,
 378        cx: &mut Context<Self>,
 379    ) -> MessageId {
 380        let id = self.next_message_id.post_inc();
 381        self.messages.push(Message {
 382            id,
 383            role,
 384            text: text.into(),
 385        });
 386        self.touch_updated_at();
 387        cx.emit(ThreadEvent::MessageAdded(id));
 388        id
 389    }
 390
 391    pub fn edit_message(
 392        &mut self,
 393        id: MessageId,
 394        new_role: Role,
 395        new_text: String,
 396        cx: &mut Context<Self>,
 397    ) -> bool {
 398        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 399            return false;
 400        };
 401        message.role = new_role;
 402        message.text = new_text;
 403        self.touch_updated_at();
 404        cx.emit(ThreadEvent::MessageEdited(id));
 405        true
 406    }
 407
 408    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 409        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 410            return false;
 411        };
 412        self.messages.remove(index);
 413        self.context_by_message.remove(&id);
 414        self.touch_updated_at();
 415        cx.emit(ThreadEvent::MessageDeleted(id));
 416        true
 417    }
 418
 419    /// Returns the representation of this [`Thread`] in a textual form.
 420    ///
 421    /// This is the representation we use when attaching a thread as context to another thread.
 422    pub fn text(&self) -> String {
 423        let mut text = String::new();
 424
 425        for message in &self.messages {
 426            text.push_str(match message.role {
 427                language_model::Role::User => "User:",
 428                language_model::Role::Assistant => "Assistant:",
 429                language_model::Role::System => "System:",
 430            });
 431            text.push('\n');
 432
 433            text.push_str(&message.text);
 434            text.push('\n');
 435        }
 436
 437        text
 438    }
 439
 440    /// Serializes this thread into a format for storage or telemetry.
 441    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 442        let initial_project_snapshot = self.initial_project_snapshot.clone();
 443        cx.spawn(async move |this, cx| {
 444            let initial_project_snapshot = initial_project_snapshot.await;
 445            this.read_with(cx, |this, _| SerializedThread {
 446                summary: this.summary_or_default(),
 447                updated_at: this.updated_at(),
 448                messages: this
 449                    .messages()
 450                    .map(|message| SerializedMessage {
 451                        id: message.id,
 452                        role: message.role,
 453                        text: message.text.clone(),
 454                        tool_uses: this
 455                            .tool_uses_for_message(message.id)
 456                            .into_iter()
 457                            .chain(this.scripting_tool_uses_for_message(message.id))
 458                            .map(|tool_use| SerializedToolUse {
 459                                id: tool_use.id,
 460                                name: tool_use.name,
 461                                input: tool_use.input,
 462                            })
 463                            .collect(),
 464                        tool_results: this
 465                            .tool_results_for_message(message.id)
 466                            .into_iter()
 467                            .chain(this.scripting_tool_results_for_message(message.id))
 468                            .map(|tool_result| SerializedToolResult {
 469                                tool_use_id: tool_result.tool_use_id.clone(),
 470                                is_error: tool_result.is_error,
 471                                content: tool_result.content.clone(),
 472                            })
 473                            .collect(),
 474                    })
 475                    .collect(),
 476                initial_project_snapshot,
 477            })
 478        })
 479    }
 480
 481    pub fn send_to_model(
 482        &mut self,
 483        model: Arc<dyn LanguageModel>,
 484        request_kind: RequestKind,
 485        cx: &mut Context<Self>,
 486    ) {
 487        let mut request = self.to_completion_request(request_kind, cx);
 488        request.tools = {
 489            let mut tools = Vec::new();
 490
 491            if self.tools.is_scripting_tool_enabled() {
 492                tools.push(LanguageModelRequestTool {
 493                    name: ScriptingTool::NAME.into(),
 494                    description: ScriptingTool::DESCRIPTION.into(),
 495                    input_schema: ScriptingTool::input_schema(),
 496                });
 497            }
 498
 499            tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 500                LanguageModelRequestTool {
 501                    name: tool.name(),
 502                    description: tool.description(),
 503                    input_schema: tool.input_schema(),
 504                }
 505            }));
 506
 507            tools
 508        };
 509
 510        self.stream_completion(request, model, cx);
 511    }
 512
 513    pub fn to_completion_request(
 514        &self,
 515        request_kind: RequestKind,
 516        cx: &App,
 517    ) -> LanguageModelRequest {
 518        let worktree_root_names = self
 519            .project
 520            .read(cx)
 521            .visible_worktrees(cx)
 522            .map(|worktree| {
 523                let worktree = worktree.read(cx);
 524                AssistantSystemPromptWorktree {
 525                    root_name: worktree.root_name().into(),
 526                    abs_path: worktree.abs_path(),
 527                }
 528            })
 529            .collect::<Vec<_>>();
 530        let system_prompt = self
 531            .prompt_builder
 532            .generate_assistant_system_prompt(worktree_root_names)
 533            .context("failed to generate assistant system prompt")
 534            .log_err()
 535            .unwrap_or_default();
 536
 537        let mut request = LanguageModelRequest {
 538            messages: vec![LanguageModelRequestMessage {
 539                role: Role::System,
 540                content: vec![MessageContent::Text(system_prompt)],
 541                cache: true,
 542            }],
 543            tools: Vec::new(),
 544            stop: Vec::new(),
 545            temperature: None,
 546        };
 547
 548        let mut referenced_context_ids = HashSet::default();
 549
 550        for message in &self.messages {
 551            if let Some(context_ids) = self.context_by_message.get(&message.id) {
 552                referenced_context_ids.extend(context_ids);
 553            }
 554
 555            let mut request_message = LanguageModelRequestMessage {
 556                role: message.role,
 557                content: Vec::new(),
 558                cache: false,
 559            };
 560
 561            match request_kind {
 562                RequestKind::Chat => {
 563                    self.tool_use
 564                        .attach_tool_results(message.id, &mut request_message);
 565                    self.scripting_tool_use
 566                        .attach_tool_results(message.id, &mut request_message);
 567                }
 568                RequestKind::Summarize => {
 569                    // We don't care about tool use during summarization.
 570                }
 571            }
 572
 573            if !message.text.is_empty() {
 574                request_message
 575                    .content
 576                    .push(MessageContent::Text(message.text.clone()));
 577            }
 578
 579            match request_kind {
 580                RequestKind::Chat => {
 581                    self.tool_use
 582                        .attach_tool_uses(message.id, &mut request_message);
 583                    self.scripting_tool_use
 584                        .attach_tool_uses(message.id, &mut request_message);
 585                }
 586                RequestKind::Summarize => {
 587                    // We don't care about tool use during summarization.
 588                }
 589            };
 590
 591            request.messages.push(request_message);
 592        }
 593
 594        if !referenced_context_ids.is_empty() {
 595            let mut context_message = LanguageModelRequestMessage {
 596                role: Role::User,
 597                content: Vec::new(),
 598                cache: false,
 599            };
 600
 601            let referenced_context = referenced_context_ids
 602                .into_iter()
 603                .filter_map(|context_id| self.context.get(context_id))
 604                .cloned();
 605            attach_context_to_message(&mut context_message, referenced_context);
 606
 607            request.messages.push(context_message);
 608        }
 609
 610        self.attach_stale_files(&mut request.messages, cx);
 611
 612        request
 613    }
 614
 615    fn attach_stale_files(&self, messages: &mut Vec<LanguageModelRequestMessage>, cx: &App) {
 616        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 617
 618        let mut stale_message = String::new();
 619
 620        for stale_file in self.action_log.read(cx).stale_buffers(cx) {
 621            let Some(file) = stale_file.read(cx).file() else {
 622                continue;
 623            };
 624
 625            if stale_message.is_empty() {
 626                write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
 627            }
 628
 629            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 630        }
 631
 632        if !stale_message.is_empty() {
 633            let context_message = LanguageModelRequestMessage {
 634                role: Role::User,
 635                content: vec![stale_message.into()],
 636                cache: false,
 637            };
 638
 639            messages.push(context_message);
 640        }
 641    }
 642
    /// Streams a completion for `request` from `model` and applies it to the thread.
    ///
    /// A task is spawned and registered in `pending_completions` under a newly
    /// issued id. While streaming, each event updates the thread (new
    /// assistant message, appended text, requested tool use, token usage) and
    /// emits `ThreadEvent::StreamedCompletion`. When the stream ends without
    /// error the pending completion is removed and, if the thread has no
    /// summary yet and holds at least two messages, a summary is requested.
    ///
    /// Afterwards `ThreadEvent::UsePendingTools` is emitted if the model
    /// stopped for tool use; on failure a `ThreadEvent::ShowError` is emitted
    /// and the last completion is cancelled. In both cases
    /// `ThreadEvent::DoneStreaming` is emitted and, when the initial usage
    /// could be read, a telemetry event reports the tokens consumed since the
    /// task started.
    pub fn stream_completion(
        &mut self,
        request: LanguageModelRequest,
        model: Arc<dyn LanguageModel>,
        cx: &mut Context<Self>,
    ) {
        let pending_completion_id = post_inc(&mut self.completion_count);

        let task = cx.spawn(async move |thread, cx| {
            let stream = model.stream_completion(request, &cx);
            // Remember the usage before this completion so the telemetry event
            // below can report only what this completion consumed.
            let initial_token_usage =
                thread.read_with(cx, |thread, _cx| thread.cumulative_token_usage.clone());
            let stream_completion = async {
                let mut events = stream.await?;
                // Used when the stream never reports an explicit stop reason.
                let mut stop_reason = StopReason::EndTurn;
                let mut current_token_usage = TokenUsage::default();

                while let Some(event) = events.next().await {
                    let event = event?;

                    thread.update(cx, |thread, cx| {
                        match event {
                            LanguageModelCompletionEvent::StartMessage { .. } => {
                                thread.insert_message(Role::Assistant, String::new(), cx);
                            }
                            LanguageModelCompletionEvent::Stop(reason) => {
                                stop_reason = reason;
                            }
                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
                                // Add the new usage and subtract the one applied by the
                                // previous update, so only the difference accumulates.
                                thread.cumulative_token_usage =
                                    thread.cumulative_token_usage.clone() + token_usage.clone()
                                        - current_token_usage.clone();
                                current_token_usage = token_usage;
                            }
                            LanguageModelCompletionEvent::Text(chunk) => {
                                if let Some(last_message) = thread.messages.last_mut() {
                                    if last_message.role == Role::Assistant {
                                        last_message.text.push_str(&chunk);
                                        cx.emit(ThreadEvent::StreamedAssistantText(
                                            last_message.id,
                                            chunk,
                                        ));
                                    } else {
                                        // If we don't have an Assistant message yet, assume this chunk marks the beginning
                                        // of a new Assistant response.
                                        //
                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
                                        // will result in duplicating the text of the chunk in the rendered Markdown.
                                        thread.insert_message(Role::Assistant, chunk, cx);
                                    };
                                }
                            }
                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
                                // Tool uses are attached to the most recent assistant
                                // message and routed to the scripting or regular tool
                                // state depending on the tool's name.
                                if let Some(last_assistant_message) = thread
                                    .messages
                                    .iter()
                                    .rfind(|message| message.role == Role::Assistant)
                                {
                                    if tool_use.name.as_ref() == ScriptingTool::NAME {
                                        thread
                                            .scripting_tool_use
                                            .request_tool_use(last_assistant_message.id, tool_use);
                                    } else {
                                        thread
                                            .tool_use
                                            .request_tool_use(last_assistant_message.id, tool_use);
                                    }
                                }
                            }
                        }

                        thread.touch_updated_at();
                        cx.emit(ThreadEvent::StreamedCompletion);
                        cx.notify();
                    })?;

                    smol::future::yield_now().await;
                }

                thread.update(cx, |thread, cx| {
                    thread
                        .pending_completions
                        .retain(|completion| completion.id != pending_completion_id);

                    // Generate a title once the thread has at least two messages
                    // and no summary yet.
                    if thread.summary.is_none() && thread.messages.len() >= 2 {
                        thread.summarize(cx);
                    }
                })?;

                anyhow::Ok(stop_reason)
            };

            let result = stream_completion.await;

            thread
                .update(cx, |thread, cx| {
                    match result.as_ref() {
                        Ok(stop_reason) => match stop_reason {
                            StopReason::ToolUse => {
                                cx.emit(ThreadEvent::UsePendingTools);
                            }
                            StopReason::EndTurn => {}
                            StopReason::MaxTokens => {}
                        },
                        Err(error) => {
                            if error.is::<PaymentRequiredError>() {
                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
                            } else if error.is::<MaxMonthlySpendReachedError>() {
                                cx.emit(ThreadEvent::ShowError(
                                    ThreadError::MaxMonthlySpendReached,
                                ));
                            } else {
                                // Show the whole error chain, one cause per line.
                                let error_message = error
                                    .chain()
                                    .map(|err| err.to_string())
                                    .collect::<Vec<_>>()
                                    .join("\n");
                                cx.emit(ThreadEvent::ShowError(ThreadError::Message(
                                    SharedString::from(error_message.clone()),
                                )));
                            }

                            thread.cancel_last_completion(cx);
                        }
                    }
                    cx.emit(ThreadEvent::DoneStreaming);

                    if let Ok(initial_usage) = initial_token_usage {
                        let usage = thread.cumulative_token_usage.clone() - initial_usage;

                        telemetry::event!(
                            "Assistant Thread Completion",
                            thread_id = thread.id().to_string(),
                            model = model.telemetry_id(),
                            model_provider = model.provider_id().to_string(),
                            input_tokens = usage.input_tokens,
                            output_tokens = usage.output_tokens,
                            cache_creation_input_tokens = usage.cache_creation_input_tokens,
                            cache_read_input_tokens = usage.cache_read_input_tokens,
                        );
                    }
                })
                .ok();
        });

        // Keep the task alive for as long as the completion is pending.
        self.pending_completions.push(PendingCompletion {
            id: pending_completion_id,
            _task: task,
        });
    }
 793
 794    pub fn summarize(&mut self, cx: &mut Context<Self>) {
 795        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
 796            return;
 797        };
 798        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
 799            return;
 800        };
 801
 802        if !provider.is_authenticated(cx) {
 803            return;
 804        }
 805
 806        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
 807        request.messages.push(LanguageModelRequestMessage {
 808            role: Role::User,
 809            content: vec![
 810                "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
 811                    .into(),
 812            ],
 813            cache: false,
 814        });
 815
 816        self.pending_summary = cx.spawn(async move |this, cx| {
 817            async move {
 818                let stream = model.stream_completion_text(request, &cx);
 819                let mut messages = stream.await?;
 820
 821                let mut new_summary = String::new();
 822                while let Some(message) = messages.stream.next().await {
 823                    let text = message?;
 824                    let mut lines = text.lines();
 825                    new_summary.extend(lines.next());
 826
 827                    // Stop if the LLM generated multiple lines.
 828                    if lines.next().is_some() {
 829                        break;
 830                    }
 831                }
 832
 833                this.update(cx, |this, cx| {
 834                    if !new_summary.is_empty() {
 835                        this.summary = Some(new_summary.into());
 836                    }
 837
 838                    cx.emit(ThreadEvent::SummaryChanged);
 839                })?;
 840
 841                anyhow::Ok(())
 842            }
 843            .log_err()
 844            .await
 845        });
 846    }
 847
 848    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
 849        let request = self.to_completion_request(RequestKind::Chat, cx);
 850        let pending_tool_uses = self
 851            .tool_use
 852            .pending_tool_uses()
 853            .into_iter()
 854            .filter(|tool_use| tool_use.status.is_idle())
 855            .cloned()
 856            .collect::<Vec<_>>();
 857
 858        for tool_use in pending_tool_uses {
 859            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
 860                let task = tool.run(
 861                    tool_use.input,
 862                    &request.messages,
 863                    self.project.clone(),
 864                    self.action_log.clone(),
 865                    cx,
 866                );
 867
 868                self.insert_tool_output(tool_use.id.clone(), task, cx);
 869            }
 870        }
 871
 872        let pending_scripting_tool_uses = self
 873            .scripting_tool_use
 874            .pending_tool_uses()
 875            .into_iter()
 876            .filter(|tool_use| tool_use.status.is_idle())
 877            .cloned()
 878            .collect::<Vec<_>>();
 879
 880        for scripting_tool_use in pending_scripting_tool_uses {
 881            let task = match ScriptingTool::deserialize_input(scripting_tool_use.input) {
 882                Err(err) => Task::ready(Err(err.into())),
 883                Ok(input) => {
 884                    let (script_id, script_task) =
 885                        self.scripting_session.update(cx, move |session, cx| {
 886                            session.run_script(input.lua_script, cx)
 887                        });
 888
 889                    let session = self.scripting_session.clone();
 890                    cx.spawn(async move |_, cx| {
 891                        script_task.await;
 892
 893                        let message = session.read_with(cx, |session, _cx| {
 894                            // Using a id to get the script output seems impractical.
 895                            // Why not just include it in the Task result?
 896                            // This is because we'll later report the script state as it runs,
 897                            session
 898                                .get(script_id)
 899                                .output_message_for_llm()
 900                                .expect("Script shouldn't still be running")
 901                        })?;
 902
 903                        Ok(message)
 904                    })
 905                }
 906            };
 907
 908            self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx);
 909        }
 910    }
 911
 912    pub fn insert_tool_output(
 913        &mut self,
 914        tool_use_id: LanguageModelToolUseId,
 915        output: Task<Result<String>>,
 916        cx: &mut Context<Self>,
 917    ) {
 918        let insert_output_task = cx.spawn({
 919            let tool_use_id = tool_use_id.clone();
 920            async move |thread, cx| {
 921                let output = output.await;
 922                thread
 923                    .update(cx, |thread, cx| {
 924                        let pending_tool_use = thread
 925                            .tool_use
 926                            .insert_tool_output(tool_use_id.clone(), output);
 927
 928                        cx.emit(ThreadEvent::ToolFinished {
 929                            tool_use_id,
 930                            pending_tool_use,
 931                            canceled: false,
 932                        });
 933                    })
 934                    .ok();
 935            }
 936        });
 937
 938        self.tool_use
 939            .run_pending_tool(tool_use_id, insert_output_task);
 940    }
 941
 942    pub fn insert_scripting_tool_output(
 943        &mut self,
 944        tool_use_id: LanguageModelToolUseId,
 945        output: Task<Result<String>>,
 946        cx: &mut Context<Self>,
 947    ) {
 948        let insert_output_task = cx.spawn({
 949            let tool_use_id = tool_use_id.clone();
 950            async move |thread, cx| {
 951                let output = output.await;
 952                thread
 953                    .update(cx, |thread, cx| {
 954                        let pending_tool_use = thread
 955                            .scripting_tool_use
 956                            .insert_tool_output(tool_use_id.clone(), output);
 957
 958                        cx.emit(ThreadEvent::ToolFinished {
 959                            tool_use_id,
 960                            pending_tool_use,
 961                            canceled: false,
 962                        });
 963                    })
 964                    .ok();
 965            }
 966        });
 967
 968        self.scripting_tool_use
 969            .run_pending_tool(tool_use_id, insert_output_task);
 970    }
 971
 972    pub fn attach_tool_results(
 973        &mut self,
 974        updated_context: Vec<ContextSnapshot>,
 975        cx: &mut Context<Self>,
 976    ) {
 977        self.context.extend(
 978            updated_context
 979                .into_iter()
 980                .map(|context| (context.id, context)),
 981        );
 982
 983        // Insert a user message to contain the tool results.
 984        self.insert_user_message(
 985            // TODO: Sending up a user message without any content results in the model sending back
 986            // responses that also don't have any content. We currently don't handle this case well,
 987            // so for now we provide some text to keep the model on track.
 988            "Here are the tool results.",
 989            Vec::new(),
 990            None,
 991            cx,
 992        );
 993    }
 994
 995    /// Cancels the last pending completion, if there are any pending.
 996    ///
 997    /// Returns whether a completion was canceled.
 998    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
 999        if self.pending_completions.pop().is_some() {
1000            true
1001        } else {
1002            let mut canceled = false;
1003            for pending_tool_use in self.tool_use.cancel_pending() {
1004                canceled = true;
1005                cx.emit(ThreadEvent::ToolFinished {
1006                    tool_use_id: pending_tool_use.id.clone(),
1007                    pending_tool_use: Some(pending_tool_use),
1008                    canceled: true,
1009                });
1010            }
1011            canceled
1012        }
1013    }
1014
1015    /// Reports feedback about the thread and stores it in our telemetry backend.
1016    pub fn report_feedback(&self, is_positive: bool, cx: &mut Context<Self>) -> Task<Result<()>> {
1017        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
1018        let serialized_thread = self.serialize(cx);
1019        let thread_id = self.id().clone();
1020        let client = self.project.read(cx).client();
1021
1022        cx.background_spawn(async move {
1023            let final_project_snapshot = final_project_snapshot.await;
1024            let serialized_thread = serialized_thread.await?;
1025            let thread_data =
1026                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
1027
1028            let rating = if is_positive { "positive" } else { "negative" };
1029            telemetry::event!(
1030                "Assistant Thread Rated",
1031                rating,
1032                thread_id,
1033                thread_data,
1034                final_project_snapshot
1035            );
1036            client.telemetry().flush_events();
1037
1038            Ok(())
1039        })
1040    }
1041
1042    /// Create a snapshot of the current project state including git information and unsaved buffers.
1043    fn project_snapshot(
1044        project: Entity<Project>,
1045        cx: &mut Context<Self>,
1046    ) -> Task<Arc<ProjectSnapshot>> {
1047        let worktree_snapshots: Vec<_> = project
1048            .read(cx)
1049            .visible_worktrees(cx)
1050            .map(|worktree| Self::worktree_snapshot(worktree, cx))
1051            .collect();
1052
1053        cx.spawn(async move |_, cx| {
1054            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
1055
1056            let mut unsaved_buffers = Vec::new();
1057            cx.update(|app_cx| {
1058                let buffer_store = project.read(app_cx).buffer_store();
1059                for buffer_handle in buffer_store.read(app_cx).buffers() {
1060                    let buffer = buffer_handle.read(app_cx);
1061                    if buffer.is_dirty() {
1062                        if let Some(file) = buffer.file() {
1063                            let path = file.path().to_string_lossy().to_string();
1064                            unsaved_buffers.push(path);
1065                        }
1066                    }
1067                }
1068            })
1069            .ok();
1070
1071            Arc::new(ProjectSnapshot {
1072                worktree_snapshots,
1073                unsaved_buffer_paths: unsaved_buffers,
1074                timestamp: Utc::now(),
1075            })
1076        })
1077    }
1078
1079    fn worktree_snapshot(worktree: Entity<project::Worktree>, cx: &App) -> Task<WorktreeSnapshot> {
1080        cx.spawn(async move |cx| {
1081            // Get worktree path and snapshot
1082            let worktree_info = cx.update(|app_cx| {
1083                let worktree = worktree.read(app_cx);
1084                let path = worktree.abs_path().to_string_lossy().to_string();
1085                let snapshot = worktree.snapshot();
1086                (path, snapshot)
1087            });
1088
1089            let Ok((worktree_path, snapshot)) = worktree_info else {
1090                return WorktreeSnapshot {
1091                    worktree_path: String::new(),
1092                    git_state: None,
1093                };
1094            };
1095
1096            // Extract git information
1097            let git_state = match snapshot.repositories().first() {
1098                None => None,
1099                Some(repo_entry) => {
1100                    // Get branch information
1101                    let current_branch = repo_entry.branch().map(|branch| branch.name.to_string());
1102
1103                    // Get repository info
1104                    let repo_result = worktree.read_with(cx, |worktree, _cx| {
1105                        if let project::Worktree::Local(local_worktree) = &worktree {
1106                            local_worktree.get_local_repo(repo_entry).map(|local_repo| {
1107                                let repo = local_repo.repo();
1108                                (repo.remote_url("origin"), repo.head_sha(), repo.clone())
1109                            })
1110                        } else {
1111                            None
1112                        }
1113                    });
1114
1115                    match repo_result {
1116                        Ok(Some((remote_url, head_sha, repository))) => {
1117                            // Get diff asynchronously
1118                            let diff = repository
1119                                .diff(git::repository::DiffType::HeadToWorktree, cx.clone())
1120                                .await
1121                                .ok();
1122
1123                            Some(GitState {
1124                                remote_url,
1125                                head_sha,
1126                                current_branch,
1127                                diff,
1128                            })
1129                        }
1130                        Err(_) | Ok(None) => None,
1131                    }
1132                }
1133            };
1134
1135            WorktreeSnapshot {
1136                worktree_path,
1137                git_state,
1138            }
1139        })
1140    }
1141
1142    pub fn to_markdown(&self) -> Result<String> {
1143        let mut markdown = Vec::new();
1144
1145        if let Some(summary) = self.summary() {
1146            writeln!(markdown, "# {summary}\n")?;
1147        };
1148
1149        for message in self.messages() {
1150            writeln!(
1151                markdown,
1152                "## {role}\n",
1153                role = match message.role {
1154                    Role::User => "User",
1155                    Role::Assistant => "Assistant",
1156                    Role::System => "System",
1157                }
1158            )?;
1159            writeln!(markdown, "{}\n", message.text)?;
1160
1161            for tool_use in self.tool_uses_for_message(message.id) {
1162                writeln!(
1163                    markdown,
1164                    "**Use Tool: {} ({})**",
1165                    tool_use.name, tool_use.id
1166                )?;
1167                writeln!(markdown, "```json")?;
1168                writeln!(
1169                    markdown,
1170                    "{}",
1171                    serde_json::to_string_pretty(&tool_use.input)?
1172                )?;
1173                writeln!(markdown, "```")?;
1174            }
1175
1176            for tool_result in self.tool_results_for_message(message.id) {
1177                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1178                if tool_result.is_error {
1179                    write!(markdown, " (Error)")?;
1180                }
1181
1182                writeln!(markdown, "**\n")?;
1183                writeln!(markdown, "{}", tool_result.content)?;
1184            }
1185        }
1186
1187        Ok(String::from_utf8_lossy(&markdown).to_string())
1188    }
1189
    /// Returns the action log entity held by this thread (the same one passed to tools when
    /// they are run).
    pub fn action_log(&self) -> &Entity<ActionLog> {
        &self.action_log
    }
1193
    /// Returns the project entity this thread operates on.
    pub fn project(&self) -> &Entity<Project> {
        &self.project
    }
1197
    /// Returns a copy of the token usage accumulated by this thread so far.
    pub fn cumulative_token_usage(&self) -> TokenUsage {
        self.cumulative_token_usage.clone()
    }
1201}
1202
/// An error surfaced to the user through `ThreadEvent::ShowError`.
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// NOTE(review): presumably reported when the provider returns `PaymentRequiredError`;
    /// the emitting code is outside this section — confirm.
    PaymentRequired,
    /// The maximum monthly spend has been reached.
    MaxMonthlySpendReached,
    /// Any other error, carried as display text (the completion path joins the error chain
    /// with newlines to build it).
    Message(SharedString),
}
1209
/// Events emitted by a `Thread`.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown to the user.
    ShowError(ThreadError),
    /// NOTE(review): presumably emitted as completion events are streamed in — confirm at the
    /// emit site.
    StreamedCompletion,
    /// NOTE(review): presumably carries a chunk of assistant text streamed into the given
    /// message — confirm at the emit site.
    StreamedAssistantText(MessageId, String),
    /// Emitted once a completion has stopped streaming, whether it succeeded or failed.
    DoneStreaming,
    /// NOTE(review): presumably emitted when the message with this ID is added — confirm.
    MessageAdded(MessageId),
    /// NOTE(review): presumably emitted when the message with this ID is edited — confirm.
    MessageEdited(MessageId),
    /// NOTE(review): presumably emitted when the message with this ID is deleted — confirm.
    MessageDeleted(MessageId),
    /// Emitted by `Thread::summarize` after the summary request has been processed.
    SummaryChanged,
    /// NOTE(review): presumably asks the listener to call `Thread::use_pending_tools` —
    /// confirm at the emit site.
    UsePendingTools,
    /// A tool use has finished running or was canceled.
    ToolFinished {
        /// The ID of the tool use that finished.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
        /// Whether the tool was canceled by the user.
        canceled: bool,
    },
}
1230
// Allows `Thread` to emit `ThreadEvent`s through `cx.emit`.
impl EventEmitter<ThreadEvent> for Thread {}
1232
/// An in-flight completion request tracked in `Thread::pending_completions`.
struct PendingCompletion {
    /// Identifier assigned to the completion when it was queued.
    id: usize,
    /// The task driving the completion; it is owned here and never read.
    /// NOTE(review): presumably dropping it (e.g. when the entry is popped in
    /// `cancel_last_completion`) cancels the request — confirm against `gpui::Task`.
    _task: Task<()>,
}