thread.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use chrono::{DateTime, Utc};
  6use collections::{BTreeMap, HashMap, HashSet};
  7use futures::StreamExt as _;
  8use gpui::{App, Context, Entity, EventEmitter, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
 11    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
 12    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
 13    Role, StopReason,
 14};
 15use project::Project;
 16use scripting_tool::ScriptingTool;
 17use serde::{Deserialize, Serialize};
 18use util::{post_inc, TryFutureExt as _};
 19use uuid::Uuid;
 20
 21use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
 22use crate::thread_store::SavedThread;
 23use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
 24
 25#[derive(Debug, Clone, Copy)]
 26pub enum RequestKind {
 27    Chat,
 28    /// Used when summarizing a thread.
 29    Summarize,
 30}
 31
 32#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
 33pub struct ThreadId(Arc<str>);
 34
 35impl ThreadId {
 36    pub fn new() -> Self {
 37        Self(Uuid::new_v4().to_string().into())
 38    }
 39}
 40
 41impl std::fmt::Display for ThreadId {
 42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 43        write!(f, "{}", self.0)
 44    }
 45}
 46
 47#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 48pub struct MessageId(pub(crate) usize);
 49
 50impl MessageId {
 51    fn post_inc(&mut self) -> Self {
 52        Self(post_inc(&mut self.0))
 53    }
 54}
 55
 56/// A message in a [`Thread`].
 57#[derive(Debug, Clone)]
 58pub struct Message {
 59    pub id: MessageId,
 60    pub role: Role,
 61    pub text: String,
 62}
 63
 64/// A thread of conversation with the LLM.
 65pub struct Thread {
 66    id: ThreadId,
 67    updated_at: DateTime<Utc>,
 68    summary: Option<SharedString>,
 69    pending_summary: Task<Option<()>>,
 70    messages: Vec<Message>,
 71    next_message_id: MessageId,
 72    context: BTreeMap<ContextId, ContextSnapshot>,
 73    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 74    completion_count: usize,
 75    pending_completions: Vec<PendingCompletion>,
 76    project: Entity<Project>,
 77    tools: Arc<ToolWorkingSet>,
 78    tool_use: ToolUseState,
 79    scripting_tool_use: ToolUseState,
 80}
 81
 82impl Thread {
 83    pub fn new(
 84        project: Entity<Project>,
 85        tools: Arc<ToolWorkingSet>,
 86        _cx: &mut Context<Self>,
 87    ) -> Self {
 88        Self {
 89            id: ThreadId::new(),
 90            updated_at: Utc::now(),
 91            summary: None,
 92            pending_summary: Task::ready(None),
 93            messages: Vec::new(),
 94            next_message_id: MessageId(0),
 95            context: BTreeMap::default(),
 96            context_by_message: HashMap::default(),
 97            completion_count: 0,
 98            pending_completions: Vec::new(),
 99            project,
100            tools,
101            tool_use: ToolUseState::new(),
102            scripting_tool_use: ToolUseState::new(),
103        }
104    }
105
106    pub fn from_saved(
107        id: ThreadId,
108        saved: SavedThread,
109        project: Entity<Project>,
110        tools: Arc<ToolWorkingSet>,
111        _cx: &mut Context<Self>,
112    ) -> Self {
113        let next_message_id = MessageId(
114            saved
115                .messages
116                .last()
117                .map(|message| message.id.0 + 1)
118                .unwrap_or(0),
119        );
120        let tool_use =
121            ToolUseState::from_saved_messages(&saved.messages, |name| name != ScriptingTool::NAME);
122        let scripting_tool_use =
123            ToolUseState::from_saved_messages(&saved.messages, |name| name == ScriptingTool::NAME);
124
125        Self {
126            id,
127            updated_at: saved.updated_at,
128            summary: Some(saved.summary),
129            pending_summary: Task::ready(None),
130            messages: saved
131                .messages
132                .into_iter()
133                .map(|message| Message {
134                    id: message.id,
135                    role: message.role,
136                    text: message.text,
137                })
138                .collect(),
139            next_message_id,
140            context: BTreeMap::default(),
141            context_by_message: HashMap::default(),
142            completion_count: 0,
143            pending_completions: Vec::new(),
144            project,
145            tools,
146            tool_use,
147            scripting_tool_use,
148        }
149    }
150
151    pub fn id(&self) -> &ThreadId {
152        &self.id
153    }
154
155    pub fn is_empty(&self) -> bool {
156        self.messages.is_empty()
157    }
158
159    pub fn updated_at(&self) -> DateTime<Utc> {
160        self.updated_at
161    }
162
163    pub fn touch_updated_at(&mut self) {
164        self.updated_at = Utc::now();
165    }
166
167    pub fn summary(&self) -> Option<SharedString> {
168        self.summary.clone()
169    }
170
171    pub fn summary_or_default(&self) -> SharedString {
172        const DEFAULT: SharedString = SharedString::new_static("New Thread");
173        self.summary.clone().unwrap_or(DEFAULT)
174    }
175
176    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
177        self.summary = Some(summary.into());
178        cx.emit(ThreadEvent::SummaryChanged);
179    }
180
181    pub fn message(&self, id: MessageId) -> Option<&Message> {
182        self.messages.iter().find(|message| message.id == id)
183    }
184
185    pub fn messages(&self) -> impl Iterator<Item = &Message> {
186        self.messages.iter()
187    }
188
189    pub fn is_streaming(&self) -> bool {
190        !self.pending_completions.is_empty()
191    }
192
193    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
194        &self.tools
195    }
196
197    pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
198        let context = self.context_by_message.get(&id)?;
199        Some(
200            context
201                .into_iter()
202                .filter_map(|context_id| self.context.get(&context_id))
203                .cloned()
204                .collect::<Vec<_>>(),
205        )
206    }
207
208    /// Returns whether all of the tool uses have finished running.
209    pub fn all_tools_finished(&self) -> bool {
210        let mut all_pending_tool_uses = self
211            .tool_use
212            .pending_tool_uses()
213            .into_iter()
214            .chain(self.scripting_tool_use.pending_tool_uses());
215
216        // If the only pending tool uses left are the ones with errors, then that means that we've finished running all
217        // of the pending tools.
218        all_pending_tool_uses.all(|tool_use| tool_use.status.is_error())
219    }
220
221    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
222        self.tool_use.tool_uses_for_message(id)
223    }
224
225    pub fn scripting_tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
226        self.scripting_tool_use.tool_uses_for_message(id)
227    }
228
229    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
230        self.tool_use.tool_results_for_message(id)
231    }
232
233    pub fn scripting_tool_results_for_message(
234        &self,
235        id: MessageId,
236    ) -> Vec<&LanguageModelToolResult> {
237        self.scripting_tool_use.tool_results_for_message(id)
238    }
239
240    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
241        self.tool_use.message_has_tool_results(message_id)
242    }
243
244    pub fn message_has_scripting_tool_results(&self, message_id: MessageId) -> bool {
245        self.scripting_tool_use.message_has_tool_results(message_id)
246    }
247
248    pub fn insert_user_message(
249        &mut self,
250        text: impl Into<String>,
251        context: Vec<ContextSnapshot>,
252        cx: &mut Context<Self>,
253    ) -> MessageId {
254        let message_id = self.insert_message(Role::User, text, cx);
255        let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
256        self.context
257            .extend(context.into_iter().map(|context| (context.id, context)));
258        self.context_by_message.insert(message_id, context_ids);
259        message_id
260    }
261
262    pub fn insert_message(
263        &mut self,
264        role: Role,
265        text: impl Into<String>,
266        cx: &mut Context<Self>,
267    ) -> MessageId {
268        let id = self.next_message_id.post_inc();
269        self.messages.push(Message {
270            id,
271            role,
272            text: text.into(),
273        });
274        self.touch_updated_at();
275        cx.emit(ThreadEvent::MessageAdded(id));
276        id
277    }
278
279    pub fn edit_message(
280        &mut self,
281        id: MessageId,
282        new_role: Role,
283        new_text: String,
284        cx: &mut Context<Self>,
285    ) -> bool {
286        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
287            return false;
288        };
289        message.role = new_role;
290        message.text = new_text;
291        self.touch_updated_at();
292        cx.emit(ThreadEvent::MessageEdited(id));
293        true
294    }
295
296    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
297        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
298            return false;
299        };
300        self.messages.remove(index);
301        self.context_by_message.remove(&id);
302        self.touch_updated_at();
303        cx.emit(ThreadEvent::MessageDeleted(id));
304        true
305    }
306
307    /// Returns the representation of this [`Thread`] in a textual form.
308    ///
309    /// This is the representation we use when attaching a thread as context to another thread.
310    pub fn text(&self) -> String {
311        let mut text = String::new();
312
313        for message in &self.messages {
314            text.push_str(match message.role {
315                language_model::Role::User => "User:",
316                language_model::Role::Assistant => "Assistant:",
317                language_model::Role::System => "System:",
318            });
319            text.push('\n');
320
321            text.push_str(&message.text);
322            text.push('\n');
323        }
324
325        text
326    }
327
328    pub fn send_to_model(
329        &mut self,
330        model: Arc<dyn LanguageModel>,
331        request_kind: RequestKind,
332        use_tools: bool,
333        cx: &mut Context<Self>,
334    ) {
335        let mut request = self.to_completion_request(request_kind, cx);
336
337        if use_tools {
338            let mut tools = Vec::new();
339            tools.push(LanguageModelRequestTool {
340                name: ScriptingTool::NAME.into(),
341                description: ScriptingTool::DESCRIPTION.into(),
342                input_schema: ScriptingTool::input_schema(),
343            });
344
345            tools.extend(
346                self.tools()
347                    .tools(cx)
348                    .into_iter()
349                    .map(|tool| LanguageModelRequestTool {
350                        name: tool.name(),
351                        description: tool.description(),
352                        input_schema: tool.input_schema(),
353                    }),
354            );
355
356            request.tools = tools;
357        }
358
359        self.stream_completion(request, model, cx);
360    }
361
362    pub fn to_completion_request(
363        &self,
364        request_kind: RequestKind,
365        _cx: &App,
366    ) -> LanguageModelRequest {
367        let mut request = LanguageModelRequest {
368            messages: vec![],
369            tools: Vec::new(),
370            stop: Vec::new(),
371            temperature: None,
372        };
373
374        let mut referenced_context_ids = HashSet::default();
375
376        for message in &self.messages {
377            if let Some(context_ids) = self.context_by_message.get(&message.id) {
378                referenced_context_ids.extend(context_ids);
379            }
380
381            let mut request_message = LanguageModelRequestMessage {
382                role: message.role,
383                content: Vec::new(),
384                cache: false,
385            };
386
387            match request_kind {
388                RequestKind::Chat => {
389                    self.tool_use
390                        .attach_tool_results(message.id, &mut request_message);
391                    self.scripting_tool_use
392                        .attach_tool_results(message.id, &mut request_message);
393                }
394                RequestKind::Summarize => {
395                    // We don't care about tool use during summarization.
396                }
397            }
398
399            if !message.text.is_empty() {
400                request_message
401                    .content
402                    .push(MessageContent::Text(message.text.clone()));
403            }
404
405            match request_kind {
406                RequestKind::Chat => {
407                    self.tool_use
408                        .attach_tool_uses(message.id, &mut request_message);
409                    self.scripting_tool_use
410                        .attach_tool_uses(message.id, &mut request_message);
411                }
412                RequestKind::Summarize => {
413                    // We don't care about tool use during summarization.
414                }
415            };
416
417            request.messages.push(request_message);
418        }
419
420        if !referenced_context_ids.is_empty() {
421            let mut context_message = LanguageModelRequestMessage {
422                role: Role::User,
423                content: Vec::new(),
424                cache: false,
425            };
426
427            let referenced_context = referenced_context_ids
428                .into_iter()
429                .filter_map(|context_id| self.context.get(context_id))
430                .cloned();
431            attach_context_to_message(&mut context_message, referenced_context);
432
433            request.messages.push(context_message);
434        }
435
436        request
437    }
438
439    pub fn stream_completion(
440        &mut self,
441        request: LanguageModelRequest,
442        model: Arc<dyn LanguageModel>,
443        cx: &mut Context<Self>,
444    ) {
445        let pending_completion_id = post_inc(&mut self.completion_count);
446
447        let task = cx.spawn(|thread, mut cx| async move {
448            let stream = model.stream_completion(request, &cx);
449            let stream_completion = async {
450                let mut events = stream.await?;
451                let mut stop_reason = StopReason::EndTurn;
452
453                while let Some(event) = events.next().await {
454                    let event = event?;
455
456                    thread.update(&mut cx, |thread, cx| {
457                        match event {
458                            LanguageModelCompletionEvent::StartMessage { .. } => {
459                                thread.insert_message(Role::Assistant, String::new(), cx);
460                            }
461                            LanguageModelCompletionEvent::Stop(reason) => {
462                                stop_reason = reason;
463                            }
464                            LanguageModelCompletionEvent::Text(chunk) => {
465                                if let Some(last_message) = thread.messages.last_mut() {
466                                    if last_message.role == Role::Assistant {
467                                        last_message.text.push_str(&chunk);
468                                        cx.emit(ThreadEvent::StreamedAssistantText(
469                                            last_message.id,
470                                            chunk,
471                                        ));
472                                    } else {
473                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
474                                        // of a new Assistant response.
475                                        //
476                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
477                                        // will result in duplicating the text of the chunk in the rendered Markdown.
478                                        thread.insert_message(Role::Assistant, chunk, cx);
479                                    };
480                                }
481                            }
482                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
483                                if let Some(last_assistant_message) = thread
484                                    .messages
485                                    .iter()
486                                    .rfind(|message| message.role == Role::Assistant)
487                                {
488                                    if tool_use.name.as_ref() == ScriptingTool::NAME {
489                                        thread
490                                            .scripting_tool_use
491                                            .request_tool_use(last_assistant_message.id, tool_use);
492                                    } else {
493                                        thread
494                                            .tool_use
495                                            .request_tool_use(last_assistant_message.id, tool_use);
496                                    }
497                                }
498                            }
499                        }
500
501                        thread.touch_updated_at();
502                        cx.emit(ThreadEvent::StreamedCompletion);
503                        cx.notify();
504                    })?;
505
506                    smol::future::yield_now().await;
507                }
508
509                thread.update(&mut cx, |thread, cx| {
510                    thread
511                        .pending_completions
512                        .retain(|completion| completion.id != pending_completion_id);
513
514                    if thread.summary.is_none() && thread.messages.len() >= 2 {
515                        thread.summarize(cx);
516                    }
517                })?;
518
519                anyhow::Ok(stop_reason)
520            };
521
522            let result = stream_completion.await;
523
524            thread
525                .update(&mut cx, |thread, cx| match result.as_ref() {
526                    Ok(stop_reason) => match stop_reason {
527                        StopReason::ToolUse => {
528                            cx.emit(ThreadEvent::UsePendingTools);
529                        }
530                        StopReason::EndTurn => {}
531                        StopReason::MaxTokens => {}
532                    },
533                    Err(error) => {
534                        if error.is::<PaymentRequiredError>() {
535                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
536                        } else if error.is::<MaxMonthlySpendReachedError>() {
537                            cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
538                        } else {
539                            let error_message = error
540                                .chain()
541                                .map(|err| err.to_string())
542                                .collect::<Vec<_>>()
543                                .join("\n");
544                            cx.emit(ThreadEvent::ShowError(ThreadError::Message(
545                                SharedString::from(error_message.clone()),
546                            )));
547                        }
548
549                        thread.cancel_last_completion();
550                    }
551                })
552                .ok();
553        });
554
555        self.pending_completions.push(PendingCompletion {
556            id: pending_completion_id,
557            _task: task,
558        });
559    }
560
561    pub fn summarize(&mut self, cx: &mut Context<Self>) {
562        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
563            return;
564        };
565        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
566            return;
567        };
568
569        if !provider.is_authenticated(cx) {
570            return;
571        }
572
573        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
574        request.messages.push(LanguageModelRequestMessage {
575            role: Role::User,
576            content: vec![
577                "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
578                    .into(),
579            ],
580            cache: false,
581        });
582
583        self.pending_summary = cx.spawn(|this, mut cx| {
584            async move {
585                let stream = model.stream_completion_text(request, &cx);
586                let mut messages = stream.await?;
587
588                let mut new_summary = String::new();
589                while let Some(message) = messages.stream.next().await {
590                    let text = message?;
591                    let mut lines = text.lines();
592                    new_summary.extend(lines.next());
593
594                    // Stop if the LLM generated multiple lines.
595                    if lines.next().is_some() {
596                        break;
597                    }
598                }
599
600                this.update(&mut cx, |this, cx| {
601                    if !new_summary.is_empty() {
602                        this.summary = Some(new_summary.into());
603                    }
604
605                    cx.emit(ThreadEvent::SummaryChanged);
606                })?;
607
608                anyhow::Ok(())
609            }
610            .log_err()
611        });
612    }
613
614    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
615        let pending_tool_uses = self
616            .tool_use
617            .pending_tool_uses()
618            .into_iter()
619            .filter(|tool_use| tool_use.status.is_idle())
620            .cloned()
621            .collect::<Vec<_>>();
622
623        for tool_use in pending_tool_uses {
624            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
625                let task = tool.run(tool_use.input, self.project.clone(), cx);
626
627                self.insert_tool_output(tool_use.id.clone(), task, cx);
628            }
629        }
630
631        let pending_scripting_tool_uses = self
632            .scripting_tool_use
633            .pending_tool_uses()
634            .into_iter()
635            .filter(|tool_use| tool_use.status.is_idle())
636            .cloned()
637            .collect::<Vec<_>>();
638
639        for scripting_tool_use in pending_scripting_tool_uses {
640            let task = ScriptingTool.run(scripting_tool_use.input, self.project.clone(), cx);
641
642            self.insert_scripting_tool_output(scripting_tool_use.id.clone(), task, cx);
643        }
644    }
645
646    pub fn insert_tool_output(
647        &mut self,
648        tool_use_id: LanguageModelToolUseId,
649        output: Task<Result<String>>,
650        cx: &mut Context<Self>,
651    ) {
652        let insert_output_task = cx.spawn(|thread, mut cx| {
653            let tool_use_id = tool_use_id.clone();
654            async move {
655                let output = output.await;
656                thread
657                    .update(&mut cx, |thread, cx| {
658                        let pending_tool_use = thread
659                            .tool_use
660                            .insert_tool_output(tool_use_id.clone(), output);
661
662                        cx.emit(ThreadEvent::ToolFinished {
663                            tool_use_id,
664                            pending_tool_use,
665                        });
666                    })
667                    .ok();
668            }
669        });
670
671        self.tool_use
672            .run_pending_tool(tool_use_id, insert_output_task);
673    }
674
675    pub fn insert_scripting_tool_output(
676        &mut self,
677        tool_use_id: LanguageModelToolUseId,
678        output: Task<Result<String>>,
679        cx: &mut Context<Self>,
680    ) {
681        let insert_output_task = cx.spawn(|thread, mut cx| {
682            let tool_use_id = tool_use_id.clone();
683            async move {
684                let output = output.await;
685                thread
686                    .update(&mut cx, |thread, cx| {
687                        let pending_tool_use = thread
688                            .scripting_tool_use
689                            .insert_tool_output(tool_use_id.clone(), output);
690
691                        cx.emit(ThreadEvent::ToolFinished {
692                            tool_use_id,
693                            pending_tool_use,
694                        });
695                    })
696                    .ok();
697            }
698        });
699
700        self.scripting_tool_use
701            .run_pending_tool(tool_use_id, insert_output_task);
702    }
703
704    pub fn send_tool_results_to_model(
705        &mut self,
706        model: Arc<dyn LanguageModel>,
707        cx: &mut Context<Self>,
708    ) {
709        // Insert a user message to contain the tool results.
710        self.insert_user_message(
711            // TODO: Sending up a user message without any content results in the model sending back
712            // responses that also don't have any content. We currently don't handle this case well,
713            // so for now we provide some text to keep the model on track.
714            "Here are the tool results.",
715            Vec::new(),
716            cx,
717        );
718        self.send_to_model(model, RequestKind::Chat, true, cx);
719    }
720
721    /// Cancels the last pending completion, if there are any pending.
722    ///
723    /// Returns whether a completion was canceled.
724    pub fn cancel_last_completion(&mut self) -> bool {
725        if let Some(_last_completion) = self.pending_completions.pop() {
726            true
727        } else {
728            false
729        }
730    }
731}
732
733#[derive(Debug, Clone)]
734pub enum ThreadError {
735    PaymentRequired,
736    MaxMonthlySpendReached,
737    Message(SharedString),
738}
739
740#[derive(Debug, Clone)]
741pub enum ThreadEvent {
742    ShowError(ThreadError),
743    StreamedCompletion,
744    StreamedAssistantText(MessageId, String),
745    MessageAdded(MessageId),
746    MessageEdited(MessageId),
747    MessageDeleted(MessageId),
748    SummaryChanged,
749    UsePendingTools,
750    ToolFinished {
751        #[allow(unused)]
752        tool_use_id: LanguageModelToolUseId,
753        /// The pending tool use that corresponds to this tool.
754        pending_tool_use: Option<PendingToolUse>,
755    },
756}
757
758impl EventEmitter<ThreadEvent> for Thread {}
759
760struct PendingCompletion {
761    id: usize,
762    _task: Task<()>,
763}