thread.rs

   1use std::fmt::Write as _;
   2use std::io::Write;
   3use std::sync::Arc;
   4
   5use anyhow::{Context as _, Result};
   6use assistant_tool::{ActionLog, ToolWorkingSet};
   7use chrono::{DateTime, Utc};
   8use collections::{BTreeMap, HashMap, HashSet};
   9use futures::future::Shared;
  10use futures::{FutureExt, StreamExt as _};
  11use git;
  12use gpui::{App, AppContext, Context, Entity, EventEmitter, SharedString, Task};
  13use language_model::{
  14    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
  15    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
  16    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
  17    Role, StopReason, TokenUsage,
  18};
  19use project::Project;
  20use prompt_store::{AssistantSystemPromptWorktree, PromptBuilder};
  21use scripting_tool::{ScriptingSession, ScriptingTool};
  22use serde::{Deserialize, Serialize};
  23use util::{post_inc, ResultExt, TryFutureExt as _};
  24use uuid::Uuid;
  25
  26use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
  27use crate::thread_store::{
  28    SerializedMessage, SerializedThread, SerializedToolResult, SerializedToolUse,
  29};
  30use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
  31
  32#[derive(Debug, Clone, Copy)]
  33pub enum RequestKind {
  34    Chat,
  35    /// Used when summarizing a thread.
  36    Summarize,
  37}
  38
  39#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
  40pub struct ThreadId(Arc<str>);
  41
  42impl ThreadId {
  43    pub fn new() -> Self {
  44        Self(Uuid::new_v4().to_string().into())
  45    }
  46}
  47
  48impl std::fmt::Display for ThreadId {
  49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
  50        write!(f, "{}", self.0)
  51    }
  52}
  53
  54#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
  55pub struct MessageId(pub(crate) usize);
  56
  57impl MessageId {
  58    fn post_inc(&mut self) -> Self {
  59        Self(post_inc(&mut self.0))
  60    }
  61}
  62
  63/// A message in a [`Thread`].
  64#[derive(Debug, Clone)]
  65pub struct Message {
  66    pub id: MessageId,
  67    pub role: Role,
  68    pub text: String,
  69}
  70
  71#[derive(Debug, Clone, Serialize, Deserialize)]
  72pub struct ProjectSnapshot {
  73    pub worktree_snapshots: Vec<WorktreeSnapshot>,
  74    pub unsaved_buffer_paths: Vec<String>,
  75    pub timestamp: DateTime<Utc>,
  76}
  77
  78#[derive(Debug, Clone, Serialize, Deserialize)]
  79pub struct WorktreeSnapshot {
  80    pub worktree_path: String,
  81    pub git_state: Option<GitState>,
  82}
  83
  84#[derive(Debug, Clone, Serialize, Deserialize)]
  85pub struct GitState {
  86    pub remote_url: Option<String>,
  87    pub head_sha: Option<String>,
  88    pub current_branch: Option<String>,
  89    pub diff: Option<String>,
  90}
  91
  92/// A thread of conversation with the LLM.
  93pub struct Thread {
  94    id: ThreadId,
  95    updated_at: DateTime<Utc>,
  96    summary: Option<SharedString>,
  97    pending_summary: Task<Option<()>>,
  98    messages: Vec<Message>,
  99    next_message_id: MessageId,
 100    context: BTreeMap<ContextId, ContextSnapshot>,
 101    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 102    completion_count: usize,
 103    pending_completions: Vec<PendingCompletion>,
 104    project: Entity<Project>,
 105    prompt_builder: Arc<PromptBuilder>,
 106    tools: Arc<ToolWorkingSet>,
 107    tool_use: ToolUseState,
 108    action_log: Entity<ActionLog>,
 109    scripting_session: Entity<ScriptingSession>,
 110    scripting_tool_use: ToolUseState,
 111    initial_project_snapshot: Shared<Task<Option<Arc<ProjectSnapshot>>>>,
 112    cumulative_token_usage: TokenUsage,
 113}
 114
 115impl Thread {
 116    pub fn new(
 117        project: Entity<Project>,
 118        tools: Arc<ToolWorkingSet>,
 119        prompt_builder: Arc<PromptBuilder>,
 120        cx: &mut Context<Self>,
 121    ) -> Self {
 122        Self {
 123            id: ThreadId::new(),
 124            updated_at: Utc::now(),
 125            summary: None,
 126            pending_summary: Task::ready(None),
 127            messages: Vec::new(),
 128            next_message_id: MessageId(0),
 129            context: BTreeMap::default(),
 130            context_by_message: HashMap::default(),
 131            completion_count: 0,
 132            pending_completions: Vec::new(),
 133            project: project.clone(),
 134            prompt_builder,
 135            tools,
 136            tool_use: ToolUseState::new(),
 137            scripting_session: cx.new(|cx| ScriptingSession::new(project.clone(), cx)),
 138            scripting_tool_use: ToolUseState::new(),
 139            action_log: cx.new(|_| ActionLog::new()),
 140            initial_project_snapshot: {
 141                let project_snapshot = Self::project_snapshot(project, cx);
 142                cx.foreground_executor()
 143                    .spawn(async move { Some(project_snapshot.await) })
 144                    .shared()
 145            },
 146            cumulative_token_usage: TokenUsage::default(),
 147        }
 148    }
 149
 150    pub fn deserialize(
 151        id: ThreadId,
 152        serialized: SerializedThread,
 153        project: Entity<Project>,
 154        tools: Arc<ToolWorkingSet>,
 155        prompt_builder: Arc<PromptBuilder>,
 156        cx: &mut Context<Self>,
 157    ) -> Self {
 158        let next_message_id = MessageId(
 159            serialized
 160                .messages
 161                .last()
 162                .map(|message| message.id.0 + 1)
 163                .unwrap_or(0),
 164        );
 165        let tool_use = ToolUseState::from_serialized_messages(&serialized.messages, |name| {
 166            name != ScriptingTool::NAME
 167        });
 168        let scripting_tool_use =
 169            ToolUseState::from_serialized_messages(&serialized.messages, |name| {
 170                name == ScriptingTool::NAME
 171            });
 172        let scripting_session = cx.new(|cx| ScriptingSession::new(project.clone(), cx));
 173
 174        Self {
 175            id,
 176            updated_at: serialized.updated_at,
 177            summary: Some(serialized.summary),
 178            pending_summary: Task::ready(None),
 179            messages: serialized
 180                .messages
 181                .into_iter()
 182                .map(|message| Message {
 183                    id: message.id,
 184                    role: message.role,
 185                    text: message.text,
 186                })
 187                .collect(),
 188            next_message_id,
 189            context: BTreeMap::default(),
 190            context_by_message: HashMap::default(),
 191            completion_count: 0,
 192            pending_completions: Vec::new(),
 193            project,
 194            prompt_builder,
 195            tools,
 196            tool_use,
 197            action_log: cx.new(|_| ActionLog::new()),
 198            scripting_session,
 199            scripting_tool_use,
 200            initial_project_snapshot: Task::ready(serialized.initial_project_snapshot).shared(),
 201            // TODO: persist token usage?
 202            cumulative_token_usage: TokenUsage::default(),
 203        }
 204    }
 205
 206    pub fn id(&self) -> &ThreadId {
 207        &self.id
 208    }
 209
 210    pub fn is_empty(&self) -> bool {
 211        self.messages.is_empty()
 212    }
 213
 214    pub fn updated_at(&self) -> DateTime<Utc> {
 215        self.updated_at
 216    }
 217
 218    pub fn touch_updated_at(&mut self) {
 219        self.updated_at = Utc::now();
 220    }
 221
 222    pub fn summary(&self) -> Option<SharedString> {
 223        self.summary.clone()
 224    }
 225
 226    pub fn summary_or_default(&self) -> SharedString {
 227        const DEFAULT: SharedString = SharedString::new_static("New Thread");
 228        self.summary.clone().unwrap_or(DEFAULT)
 229    }
 230
 231    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
 232        self.summary = Some(summary.into());
 233        cx.emit(ThreadEvent::SummaryChanged);
 234    }
 235
 236    pub fn message(&self, id: MessageId) -> Option<&Message> {
 237        self.messages.iter().find(|message| message.id == id)
 238    }
 239
 240    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 241        self.messages.iter()
 242    }
 243
 244    pub fn is_generating(&self) -> bool {
 245        !self.pending_completions.is_empty() || !self.all_tools_finished()
 246    }
 247
 248    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 249        &self.tools
 250    }
 251
 252    pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
 253        let context = self.context_by_message.get(&id)?;
 254        Some(
 255            context
 256                .into_iter()
 257                .filter_map(|context_id| self.context.get(&context_id))
 258                .cloned()
 259                .collect::<Vec<_>>(),
 260        )
 261    }
 262
 263    /// Returns whether all of the tool uses have finished running.
 264    pub fn all_tools_finished(&self) -> bool {
 265        let mut all_pending_tool_uses = self
 266            .tool_use
 267            .pending_tool_uses()
 268            .into_iter()
 269            .chain(self.scripting_tool_use.pending_tool_uses());
 270
 271        // If the only pending tool uses left are the ones with errors, then
 272        // that means that we've finished running all of the pending tools.
 273        all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
 274    }
 275
 276    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
 277        self.tool_use.tool_uses_for_message(id)
 278    }
 279
 280    pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
 281        self.scripting_tool_use.tool_uses_for_message(id)
 282    }
 283
 284    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
 285        self.tool_use.tool_results_for_message(id)
 286    }
 287
 288    pub fn tool_result(&self, id: &LanguageModelToolUseId) -> Option<&LanguageModelToolResult> {
 289        self.tool_use.tool_result(id)
 290    }
 291
 292    pub fn scripting_tool_results_for_message(
 293        &self,
 294        id: MessageId,
 295    ) -> Vec<&LanguageModelToolResult> {
 296        self.scripting_tool_use.tool_results_for_message(id)
 297    }
 298
 299    pub fn scripting_changed_buffers<'a>(
 300        &self,
 301        cx: &'a App,
 302    ) -> impl ExactSizeIterator<Item = &'a Entity<language::Buffer>> {
 303        self.scripting_session.read(cx).changed_buffers()
 304    }
 305
 306    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
 307        self.tool_use.message_has_tool_results(message_id)
 308    }
 309
 310    pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
 311        self.scripting_tool_use.message_has_tool_results(message_id)
 312    }
 313
 314    pub fn insert_user_message(
 315        &mut self,
 316        text: impl Into<String>,
 317        context: Vec<ContextSnapshot>,
 318        cx: &mut Context<Self>,
 319    ) -> MessageId {
 320        let message_id = self.insert_message(Role::User, text, cx);
 321        let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
 322        self.context
 323            .extend(context.into_iter().map(|context| (context.id, context)));
 324        self.context_by_message.insert(message_id, context_ids);
 325        message_id
 326    }
 327
 328    pub fn insert_message(
 329        &mut self,
 330        role: Role,
 331        text: impl Into<String>,
 332        cx: &mut Context<Self>,
 333    ) -> MessageId {
 334        let id = self.next_message_id.post_inc();
 335        self.messages.push(Message {
 336            id,
 337            role,
 338            text: text.into(),
 339        });
 340        self.touch_updated_at();
 341        cx.emit(ThreadEvent::MessageAdded(id));
 342        id
 343    }
 344
 345    pub fn edit_message(
 346        &mut self,
 347        id: MessageId,
 348        new_role: Role,
 349        new_text: String,
 350        cx: &mut Context<Self>,
 351    ) -> bool {
 352        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
 353            return false;
 354        };
 355        message.role = new_role;
 356        message.text = new_text;
 357        self.touch_updated_at();
 358        cx.emit(ThreadEvent::MessageEdited(id));
 359        true
 360    }
 361
 362    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
 363        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
 364            return false;
 365        };
 366        self.messages.remove(index);
 367        self.context_by_message.remove(&id);
 368        self.touch_updated_at();
 369        cx.emit(ThreadEvent::MessageDeleted(id));
 370        true
 371    }
 372
 373    /// Returns the representation of this [`Thread`] in a textual form.
 374    ///
 375    /// This is the representation we use when attaching a thread as context to another thread.
 376    pub fn text(&self) -> String {
 377        let mut text = String::new();
 378
 379        for message in &self.messages {
 380            text.push_str(match message.role {
 381                language_model::Role::User => "User:",
 382                language_model::Role::Assistant => "Assistant:",
 383                language_model::Role::System => "System:",
 384            });
 385            text.push('\n');
 386
 387            text.push_str(&message.text);
 388            text.push('\n');
 389        }
 390
 391        text
 392    }
 393
 394    /// Serializes this thread into a format for storage or telemetry.
 395    pub fn serialize(&self, cx: &mut Context<Self>) -> Task<Result<SerializedThread>> {
 396        let initial_project_snapshot = self.initial_project_snapshot.clone();
 397        cx.spawn(|this, cx| async move {
 398            let initial_project_snapshot = initial_project_snapshot.await;
 399            this.read_with(&cx, |this, _| SerializedThread {
 400                summary: this.summary_or_default(),
 401                updated_at: this.updated_at(),
 402                messages: this
 403                    .messages()
 404                    .map(|message| SerializedMessage {
 405                        id: message.id,
 406                        role: message.role,
 407                        text: message.text.clone(),
 408                        tool_uses: this
 409                            .tool_uses_for_message(message.id)
 410                            .into_iter()
 411                            .chain(this.scripting_tool_uses_for_message(message.id))
 412                            .map(|tool_use| SerializedToolUse {
 413                                id: tool_use.id,
 414                                name: tool_use.name,
 415                                input: tool_use.input,
 416                            })
 417                            .collect(),
 418                        tool_results: this
 419                            .tool_results_for_message(message.id)
 420                            .into_iter()
 421                            .chain(this.scripting_tool_results_for_message(message.id))
 422                            .map(|tool_result| SerializedToolResult {
 423                                tool_use_id: tool_result.tool_use_id.clone(),
 424                                is_error: tool_result.is_error,
 425                                content: tool_result.content.clone(),
 426                            })
 427                            .collect(),
 428                    })
 429                    .collect(),
 430                initial_project_snapshot,
 431            })
 432        })
 433    }
 434
 435    pub fn send_to_model(
 436        &mut self,
 437        model: Arc<dyn LanguageModel>,
 438        request_kind: RequestKind,
 439        cx: &mut Context<Self>,
 440    ) {
 441        let mut request = self.to_completion_request(request_kind, cx);
 442        request.tools = {
 443            let mut tools = Vec::new();
 444
 445            if self.tools.is_scripting_tool_enabled() {
 446                tools.push(LanguageModelRequestTool {
 447                    name: ScriptingTool::NAME.into(),
 448                    description: ScriptingTool::DESCRIPTION.into(),
 449                    input_schema: ScriptingTool::input_schema(),
 450                });
 451            }
 452
 453            tools.extend(self.tools().enabled_tools(cx).into_iter().map(|tool| {
 454                LanguageModelRequestTool {
 455                    name: tool.name(),
 456                    description: tool.description(),
 457                    input_schema: tool.input_schema(),
 458                }
 459            }));
 460
 461            tools
 462        };
 463
 464        self.stream_completion(request, model, cx);
 465    }
 466
 467    pub fn to_completion_request(
 468        &self,
 469        request_kind: RequestKind,
 470        cx: &App,
 471    ) -> LanguageModelRequest {
 472        let worktree_root_names = self
 473            .project
 474            .read(cx)
 475            .visible_worktrees(cx)
 476            .map(|worktree| {
 477                let worktree = worktree.read(cx);
 478                AssistantSystemPromptWorktree {
 479                    root_name: worktree.root_name().into(),
 480                    abs_path: worktree.abs_path(),
 481                }
 482            })
 483            .collect::<Vec<_>>();
 484        let system_prompt = self
 485            .prompt_builder
 486            .generate_assistant_system_prompt(worktree_root_names)
 487            .context("failed to generate assistant system prompt")
 488            .log_err()
 489            .unwrap_or_default();
 490
 491        let mut request = LanguageModelRequest {
 492            messages: vec![LanguageModelRequestMessage {
 493                role: Role::System,
 494                content: vec![MessageContent::Text(system_prompt)],
 495                cache: true,
 496            }],
 497            tools: Vec::new(),
 498            stop: Vec::new(),
 499            temperature: None,
 500        };
 501
 502        let mut referenced_context_ids = HashSet::default();
 503
 504        for message in &self.messages {
 505            if let Some(context_ids) = self.context_by_message.get(&message.id) {
 506                referenced_context_ids.extend(context_ids);
 507            }
 508
 509            let mut request_message = LanguageModelRequestMessage {
 510                role: message.role,
 511                content: Vec::new(),
 512                cache: false,
 513            };
 514
 515            match request_kind {
 516                RequestKind::Chat => {
 517                    self.tool_use
 518                        .attach_tool_results(message.id, &mut request_message);
 519                    self.scripting_tool_use
 520                        .attach_tool_results(message.id, &mut request_message);
 521                }
 522                RequestKind::Summarize => {
 523                    // We don't care about tool use during summarization.
 524                }
 525            }
 526
 527            if !message.text.is_empty() {
 528                request_message
 529                    .content
 530                    .push(MessageContent::Text(message.text.clone()));
 531            }
 532
 533            match request_kind {
 534                RequestKind::Chat => {
 535                    self.tool_use
 536                        .attach_tool_uses(message.id, &mut request_message);
 537                    self.scripting_tool_use
 538                        .attach_tool_uses(message.id, &mut request_message);
 539                }
 540                RequestKind::Summarize => {
 541                    // We don't care about tool use during summarization.
 542                }
 543            };
 544
 545            request.messages.push(request_message);
 546        }
 547
 548        if !referenced_context_ids.is_empty() {
 549            let mut context_message = LanguageModelRequestMessage {
 550                role: Role::User,
 551                content: Vec::new(),
 552                cache: false,
 553            };
 554
 555            let referenced_context = referenced_context_ids
 556                .into_iter()
 557                .filter_map(|context_id| self.context.get(context_id))
 558                .cloned();
 559            attach_context_to_message(&mut context_message, referenced_context);
 560
 561            request.messages.push(context_message);
 562        }
 563
 564        self.attach_stale_files(&mut request.messages, cx);
 565
 566        request
 567    }
 568
 569    fn attach_stale_files(&self, messages: &mut Vec<LanguageModelRequestMessage>, cx: &App) {
 570        const STALE_FILES_HEADER: &str = "These files changed since last read:";
 571
 572        let mut stale_message = String::new();
 573
 574        for stale_file in self.action_log.read(cx).stale_buffers(cx) {
 575            let Some(file) = stale_file.read(cx).file() else {
 576                continue;
 577            };
 578
 579            if stale_message.is_empty() {
 580                write!(&mut stale_message, "{}", STALE_FILES_HEADER).ok();
 581            }
 582
 583            writeln!(&mut stale_message, "- {}", file.path().display()).ok();
 584        }
 585
 586        if !stale_message.is_empty() {
 587            let context_message = LanguageModelRequestMessage {
 588                role: Role::User,
 589                content: vec![stale_message.into()],
 590                cache: false,
 591            };
 592
 593            messages.push(context_message);
 594        }
 595    }
 596
 597    pub fn stream_completion(
 598        &mut self,
 599        request: LanguageModelRequest,
 600        model: Arc<dyn LanguageModel>,
 601        cx: &mut Context<Self>,
 602    ) {
 603        let pending_completion_id = post_inc(&mut self.completion_count);
 604
 605        let task = cx.spawn(|thread, mut cx| async move {
 606            let stream = model.stream_completion(request, &cx);
 607            let stream_completion = async {
 608                let mut events = stream.await?;
 609                let mut stop_reason = StopReason::EndTurn;
 610                let mut current_token_usage = TokenUsage::default();
 611
 612                while let Some(event) = events.next().await {
 613                    let event = event?;
 614
 615                    thread.update(&mut cx, |thread, cx| {
 616                        match event {
 617                            LanguageModelCompletionEvent::StartMessage { .. } => {
 618                                thread.insert_message(Role::Assistant, String::new(), cx);
 619                            }
 620                            LanguageModelCompletionEvent::Stop(reason) => {
 621                                stop_reason = reason;
 622                            }
 623                            LanguageModelCompletionEvent::UsageUpdate(token_usage) => {
 624                                thread.cumulative_token_usage =
 625                                    thread.cumulative_token_usage.clone() + token_usage.clone()
 626                                        - current_token_usage.clone();
 627                                current_token_usage = token_usage;
 628                            }
 629                            LanguageModelCompletionEvent::Text(chunk) => {
 630                                if let Some(last_message) = thread.messages.last_mut() {
 631                                    if last_message.role == Role::Assistant {
 632                                        last_message.text.push_str(&chunk);
 633                                        cx.emit(ThreadEvent::StreamedAssistantText(
 634                                            last_message.id,
 635                                            chunk,
 636                                        ));
 637                                    } else {
 638                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
 639                                        // of a new Assistant response.
 640                                        //
 641                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
 642                                        // will result in duplicating the text of the chunk in the rendered Markdown.
 643                                        thread.insert_message(Role::Assistant, chunk, cx);
 644                                    };
 645                                }
 646                            }
 647                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
 648                                if let Some(last_assistant_message) = thread
 649                                    .messages
 650                                    .iter()
 651                                    .rfind(|message| message.role == Role::Assistant)
 652                                {
 653                                    if tool_use.name.as_ref() == ScriptingTool::NAME {
 654                                        thread
 655                                            .scripting_tool_use
 656                                            .request_tool_use(last_assistant_message.id, tool_use);
 657                                    } else {
 658                                        thread
 659                                            .tool_use
 660                                            .request_tool_use(last_assistant_message.id, tool_use);
 661                                    }
 662                                }
 663                            }
 664                        }
 665
 666                        thread.touch_updated_at();
 667                        cx.emit(ThreadEvent::StreamedCompletion);
 668                        cx.notify();
 669                    })?;
 670
 671                    smol::future::yield_now().await;
 672                }
 673
 674                thread.update(&mut cx, |thread, cx| {
 675                    thread
 676                        .pending_completions
 677                        .retain(|completion| completion.id != pending_completion_id);
 678
 679                    if thread.summary.is_none() && thread.messages.len() >= 2 {
 680                        thread.summarize(cx);
 681                    }
 682                })?;
 683
 684                anyhow::Ok(stop_reason)
 685            };
 686
 687            let result = stream_completion.await;
 688
 689            thread
 690                .update(&mut cx, |thread, cx| {
 691                    match result.as_ref() {
 692                        Ok(stop_reason) => match stop_reason {
 693                            StopReason::ToolUse => {
 694                                cx.emit(ThreadEvent::UsePendingTools);
 695                            }
 696                            StopReason::EndTurn => {}
 697                            StopReason::MaxTokens => {}
 698                        },
 699                        Err(error) => {
 700                            if error.is::<PaymentRequiredError>() {
 701                                cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
 702                            } else if error.is::<MaxMonthlySpendReachedError>() {
 703                                cx.emit(ThreadEvent::ShowError(
 704                                    ThreadError::MaxMonthlySpendReached,
 705                                ));
 706                            } else {
 707                                let error_message = error
 708                                    .chain()
 709                                    .map(|err| err.to_string())
 710                                    .collect::<Vec<_>>()
 711                                    .join("\n");
 712                                cx.emit(ThreadEvent::ShowError(ThreadError::Message(
 713                                    SharedString::from(error_message.clone()),
 714                                )));
 715                            }
 716
 717                            thread.cancel_last_completion(cx);
 718                        }
 719                    }
 720                    cx.emit(ThreadEvent::DoneStreaming);
 721                })
 722                .ok();
 723        });
 724
 725        self.pending_completions.push(PendingCompletion {
 726            id: pending_completion_id,
 727            _task: task,
 728        });
 729    }
 730
 731    pub fn summarize(&mut self, cx: &mut Context<Self>) {
 732        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
 733            return;
 734        };
 735        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
 736            return;
 737        };
 738
 739        if !provider.is_authenticated(cx) {
 740            return;
 741        }
 742
 743        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
 744        request.messages.push(LanguageModelRequestMessage {
 745            role: Role::User,
 746            content: vec![
 747                "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
 748                    .into(),
 749            ],
 750            cache: false,
 751        });
 752
 753        self.pending_summary = cx.spawn(|this, mut cx| {
 754            async move {
 755                let stream = model.stream_completion_text(request, &cx);
 756                let mut messages = stream.await?;
 757
 758                let mut new_summary = String::new();
 759                while let Some(message) = messages.stream.next().await {
 760                    let text = message?;
 761                    let mut lines = text.lines();
 762                    new_summary.extend(lines.next());
 763
 764                    // Stop if the LLM generated multiple lines.
 765                    if lines.next().is_some() {
 766                        break;
 767                    }
 768                }
 769
 770                this.update(&mut cx, |this, cx| {
 771                    if !new_summary.is_empty() {
 772                        this.summary = Some(new_summary.into());
 773                    }
 774
 775                    cx.emit(ThreadEvent::SummaryChanged);
 776                })?;
 777
 778                anyhow::Ok(())
 779            }
 780            .log_err()
 781        });
 782    }
 783
 784    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
 785        let request = self.to_completion_request(RequestKind::Chat, cx);
 786        let pending_tool_uses = self
 787            .tool_use
 788            .pending_tool_uses()
 789            .into_iter()
 790            .filter(|tool_use| tool_use.status.is_idle())
 791            .cloned()
 792            .collect::<Vec<_>>();
 793
 794        for tool_use in pending_tool_uses {
 795            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
 796                let task = tool.run(
 797                    tool_use.input,
 798                    &request.messages,
 799                    self.project.clone(),
 800                    self.action_log.clone(),
 801                    cx,
 802                );
 803
 804                self.insert_tool_output(tool_use.id.clone(), task, cx);
 805            }
 806        }
 807
 808        let pending_scripting_tool_uses = self
 809            .scripting_tool_use
 810            .pending_tool_uses()
 811            .into_iter()
 812            .filter(|tool_use| tool_use.status.is_idle())
 813            .cloned()
 814            .collect::<Vec<_>>();
 815
 816        for scripting_tool_use in pending_scripting_tool_uses {
 817            let task = match ScriptingTool::deserialize_input(scripting_tool_use.input) {
 818                Err(err) => Task::ready(Err(err.into())),
 819                Ok(input) => {
 820                    let (script_id, script_task) =
 821                        self.scripting_session.update(cx, move |session, cx| {
 822                            session.run_script(input.lua_script, cx)
 823                        });
 824
 825                    let session = self.scripting_session.clone();
 826                    cx.spawn(|_, cx| async move {
 827                        script_task.await;
 828
 829                        let message = session.read_with(&cx, |session, _cx| {
 830                            // Using a id to get the script output seems impractical.
 831                            // Why not just include it in the Task result?
 832                            // This is because we'll later report the script state as it runs,
 833                            session
 834                                .get(script_id)
 835                                .output_message_for_llm()
 836                                .expect("Script shouldn't still be running")
 837                        })?;
 838
 839                        Ok(message)
 840                    })
 841                }
 842            };
 843
 844            self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx);
 845        }
 846    }
 847
 848    pub fn insert_tool_output(
 849        &mut self,
 850        tool_use_id: LanguageModelToolUseId,
 851        output: Task<Result<String>>,
 852        cx: &mut Context<Self>,
 853    ) {
 854        let insert_output_task = cx.spawn(|thread, mut cx| {
 855            let tool_use_id = tool_use_id.clone();
 856            async move {
 857                let output = output.await;
 858                thread
 859                    .update(&mut cx, |thread, cx| {
 860                        let pending_tool_use = thread
 861                            .tool_use
 862                            .insert_tool_output(tool_use_id.clone(), output);
 863
 864                        cx.emit(ThreadEvent::ToolFinished {
 865                            tool_use_id,
 866                            pending_tool_use,
 867                            canceled: false,
 868                        });
 869                    })
 870                    .ok();
 871            }
 872        });
 873
 874        self.tool_use
 875            .run_pending_tool(tool_use_id, insert_output_task);
 876    }
 877
 878    pub fn insert_scripting_tool_output(
 879        &mut self,
 880        tool_use_id: LanguageModelToolUseId,
 881        output: Task<Result<String>>,
 882        cx: &mut Context<Self>,
 883    ) {
 884        let insert_output_task = cx.spawn(|thread, mut cx| {
 885            let tool_use_id = tool_use_id.clone();
 886            async move {
 887                let output = output.await;
 888                thread
 889                    .update(&mut cx, |thread, cx| {
 890                        let pending_tool_use = thread
 891                            .scripting_tool_use
 892                            .insert_tool_output(tool_use_id.clone(), output);
 893
 894                        cx.emit(ThreadEvent::ToolFinished {
 895                            tool_use_id,
 896                            pending_tool_use,
 897                            canceled: false,
 898                        });
 899                    })
 900                    .ok();
 901            }
 902        });
 903
 904        self.scripting_tool_use
 905            .run_pending_tool(tool_use_id, insert_output_task);
 906    }
 907
 908    pub fn attach_tool_results(
 909        &mut self,
 910        updated_context: Vec<ContextSnapshot>,
 911        cx: &mut Context<Self>,
 912    ) {
 913        self.context.extend(
 914            updated_context
 915                .into_iter()
 916                .map(|context| (context.id, context)),
 917        );
 918
 919        // Insert a user message to contain the tool results.
 920        self.insert_user_message(
 921            // TODO: Sending up a user message without any content results in the model sending back
 922            // responses that also don't have any content. We currently don't handle this case well,
 923            // so for now we provide some text to keep the model on track.
 924            "Here are the tool results.",
 925            Vec::new(),
 926            cx,
 927        );
 928    }
 929
 930    /// Cancels the last pending completion, if there are any pending.
 931    ///
 932    /// Returns whether a completion was canceled.
 933    pub fn cancel_last_completion(&mut self, cx: &mut Context<Self>) -> bool {
 934        if self.pending_completions.pop().is_some() {
 935            true
 936        } else {
 937            let mut canceled = false;
 938            for pending_tool_use in self.tool_use.cancel_pending() {
 939                canceled = true;
 940                cx.emit(ThreadEvent::ToolFinished {
 941                    tool_use_id: pending_tool_use.id.clone(),
 942                    pending_tool_use: Some(pending_tool_use),
 943                    canceled: true,
 944                });
 945            }
 946            canceled
 947        }
 948    }
 949
 950    /// Reports feedback about the thread and stores it in our telemetry backend.
 951    pub fn report_feedback(&self, is_positive: bool, cx: &mut Context<Self>) -> Task<Result<()>> {
 952        let final_project_snapshot = Self::project_snapshot(self.project.clone(), cx);
 953        let serialized_thread = self.serialize(cx);
 954        let thread_id = self.id().clone();
 955        let client = self.project.read(cx).client();
 956
 957        cx.background_spawn(async move {
 958            let final_project_snapshot = final_project_snapshot.await;
 959            let serialized_thread = serialized_thread.await?;
 960            let thread_data =
 961                serde_json::to_value(serialized_thread).unwrap_or_else(|_| serde_json::Value::Null);
 962
 963            let rating = if is_positive { "positive" } else { "negative" };
 964            telemetry::event!(
 965                "Assistant Thread Rated",
 966                rating,
 967                thread_id,
 968                thread_data,
 969                final_project_snapshot
 970            );
 971            client.telemetry().flush_events();
 972
 973            Ok(())
 974        })
 975    }
 976
 977    /// Create a snapshot of the current project state including git information and unsaved buffers.
 978    fn project_snapshot(
 979        project: Entity<Project>,
 980        cx: &mut Context<Self>,
 981    ) -> Task<Arc<ProjectSnapshot>> {
 982        let worktree_snapshots: Vec<_> = project
 983            .read(cx)
 984            .visible_worktrees(cx)
 985            .map(|worktree| Self::worktree_snapshot(worktree, cx))
 986            .collect();
 987
 988        cx.spawn(move |_, cx| async move {
 989            let worktree_snapshots = futures::future::join_all(worktree_snapshots).await;
 990
 991            let mut unsaved_buffers = Vec::new();
 992            cx.update(|app_cx| {
 993                let buffer_store = project.read(app_cx).buffer_store();
 994                for buffer_handle in buffer_store.read(app_cx).buffers() {
 995                    let buffer = buffer_handle.read(app_cx);
 996                    if buffer.is_dirty() {
 997                        if let Some(file) = buffer.file() {
 998                            let path = file.path().to_string_lossy().to_string();
 999                            unsaved_buffers.push(path);
1000                        }
1001                    }
1002                }
1003            })
1004            .ok();
1005
1006            Arc::new(ProjectSnapshot {
1007                worktree_snapshots,
1008                unsaved_buffer_paths: unsaved_buffers,
1009                timestamp: Utc::now(),
1010            })
1011        })
1012    }
1013
1014    fn worktree_snapshot(worktree: Entity<project::Worktree>, cx: &App) -> Task<WorktreeSnapshot> {
1015        cx.spawn(move |cx| async move {
1016            // Get worktree path and snapshot
1017            let worktree_info = cx.update(|app_cx| {
1018                let worktree = worktree.read(app_cx);
1019                let path = worktree.abs_path().to_string_lossy().to_string();
1020                let snapshot = worktree.snapshot();
1021                (path, snapshot)
1022            });
1023
1024            let Ok((worktree_path, snapshot)) = worktree_info else {
1025                return WorktreeSnapshot {
1026                    worktree_path: String::new(),
1027                    git_state: None,
1028                };
1029            };
1030
1031            // Extract git information
1032            let git_state = match snapshot.repositories().first() {
1033                None => None,
1034                Some(repo_entry) => {
1035                    // Get branch information
1036                    let current_branch = repo_entry.branch().map(|branch| branch.name.to_string());
1037
1038                    // Get repository info
1039                    let repo_result = worktree.read_with(&cx, |worktree, _cx| {
1040                        if let project::Worktree::Local(local_worktree) = &worktree {
1041                            local_worktree.get_local_repo(repo_entry).map(|local_repo| {
1042                                let repo = local_repo.repo();
1043                                (repo.remote_url("origin"), repo.head_sha(), repo.clone())
1044                            })
1045                        } else {
1046                            None
1047                        }
1048                    });
1049
1050                    match repo_result {
1051                        Ok(Some((remote_url, head_sha, repository))) => {
1052                            // Get diff asynchronously
1053                            let diff = repository
1054                                .diff(git::repository::DiffType::HeadToWorktree, cx)
1055                                .await
1056                                .ok();
1057
1058                            Some(GitState {
1059                                remote_url,
1060                                head_sha,
1061                                current_branch,
1062                                diff,
1063                            })
1064                        }
1065                        Err(_) | Ok(None) => None,
1066                    }
1067                }
1068            };
1069
1070            WorktreeSnapshot {
1071                worktree_path,
1072                git_state,
1073            }
1074        })
1075    }
1076
1077    pub fn to_markdown(&self) -> Result<String> {
1078        let mut markdown = Vec::new();
1079
1080        if let Some(summary) = self.summary() {
1081            writeln!(markdown, "# {summary}\n")?;
1082        };
1083
1084        for message in self.messages() {
1085            writeln!(
1086                markdown,
1087                "## {role}\n",
1088                role = match message.role {
1089                    Role::User => "User",
1090                    Role::Assistant => "Assistant",
1091                    Role::System => "System",
1092                }
1093            )?;
1094            writeln!(markdown, "{}\n", message.text)?;
1095
1096            for tool_use in self.tool_uses_for_message(message.id) {
1097                writeln!(
1098                    markdown,
1099                    "**Use Tool: {} ({})**",
1100                    tool_use.name, tool_use.id
1101                )?;
1102                writeln!(markdown, "```json")?;
1103                writeln!(
1104                    markdown,
1105                    "{}",
1106                    serde_json::to_string_pretty(&tool_use.input)?
1107                )?;
1108                writeln!(markdown, "```")?;
1109            }
1110
1111            for tool_result in self.tool_results_for_message(message.id) {
1112                write!(markdown, "**Tool Results: {}", tool_result.tool_use_id)?;
1113                if tool_result.is_error {
1114                    write!(markdown, " (Error)")?;
1115                }
1116
1117                writeln!(markdown, "**\n")?;
1118                writeln!(markdown, "{}", tool_result.content)?;
1119            }
1120        }
1121
1122        Ok(String::from_utf8_lossy(&markdown).to_string())
1123    }
1124
1125    pub fn action_log(&self) -> &Entity<ActionLog> {
1126        &self.action_log
1127    }
1128
1129    pub fn cumulative_token_usage(&self) -> TokenUsage {
1130        self.cumulative_token_usage.clone()
1131    }
1132}
1133
/// Errors carried by [`ThreadEvent::ShowError`].
#[derive(Debug, Clone)]
pub enum ThreadError {
    /// Presumably produced from a `PaymentRequiredError` returned by the model
    /// provider. NOTE(review): confirm against the code that maps errors.
    PaymentRequired,
    /// Presumably produced from a `MaxMonthlySpendReachedError`.
    /// NOTE(review): confirm against the code that maps errors.
    MaxMonthlySpendReached,
    /// Any other error, with a message describing it.
    Message(SharedString),
}
1140
/// Events emitted by a [`Thread`].
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// An error occurred that should be shown.
    ShowError(ThreadError),
    StreamedCompletion,
    /// Text streamed for the assistant message with the given id.
    StreamedAssistantText(MessageId, String),
    DoneStreaming,
    MessageAdded(MessageId),
    MessageEdited(MessageId),
    MessageDeleted(MessageId),
    /// Emitted after summary generation finishes; the summary is only replaced
    /// when the generated text is non-empty.
    SummaryChanged,
    UsePendingTools,
    /// A tool use (regular or scripting) finished running or was canceled.
    ToolFinished {
        // NOTE(review): the lint is presumably silenced because nothing reads
        // this field yet; confirm before removing either the field or the attribute.
        #[allow(unused)]
        tool_use_id: LanguageModelToolUseId,
        /// The pending tool use that corresponds to this tool.
        pending_tool_use: Option<PendingToolUse>,
        /// Whether the tool was canceled by the user.
        canceled: bool,
    },
}

impl EventEmitter<ThreadEvent> for Thread {}
1163
/// A completion that is currently in flight for a thread.
///
/// `cancel_last_completion` pops one of these and reports the completion as
/// canceled. NOTE(review): this relies on dropping `_task` cancelling the
/// underlying work (gpui `Task` drop semantics); confirm.
struct PendingCompletion {
    id: usize,
    // Held only to keep the task alive; never read directly.
    _task: Task<()>,
}