thread.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use chrono::{DateTime, Utc};
  6use collections::{BTreeMap, HashMap, HashSet};
  7use futures::StreamExt as _;
  8use gpui::{App, Context, EventEmitter, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
 11    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolUseId,
 12    MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, Role, StopReason,
 13};
 14use serde::{Deserialize, Serialize};
 15use util::{post_inc, TryFutureExt as _};
 16use uuid::Uuid;
 17
 18use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
 19use crate::thread_store::SavedThread;
 20use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
 21
 22#[derive(Debug, Clone, Copy)]
 23pub enum RequestKind {
 24    Chat,
 25    /// Used when summarizing a thread.
 26    Summarize,
 27}
 28
 29#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
 30pub struct ThreadId(Arc<str>);
 31
 32impl ThreadId {
 33    pub fn new() -> Self {
 34        Self(Uuid::new_v4().to_string().into())
 35    }
 36}
 37
 38impl std::fmt::Display for ThreadId {
 39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 40        write!(f, "{}", self.0)
 41    }
 42}
 43
 44#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 45pub struct MessageId(pub(crate) usize);
 46
 47impl MessageId {
 48    fn post_inc(&mut self) -> Self {
 49        Self(post_inc(&mut self.0))
 50    }
 51}
 52
 53/// A message in a [`Thread`].
 54#[derive(Debug, Clone)]
 55pub struct Message {
 56    pub id: MessageId,
 57    pub role: Role,
 58    pub text: String,
 59}
 60
 61/// A thread of conversation with the LLM.
 62pub struct Thread {
 63    id: ThreadId,
 64    updated_at: DateTime<Utc>,
 65    summary: Option<SharedString>,
 66    pending_summary: Task<Option<()>>,
 67    messages: Vec<Message>,
 68    next_message_id: MessageId,
 69    context: BTreeMap<ContextId, ContextSnapshot>,
 70    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 71    completion_count: usize,
 72    pending_completions: Vec<PendingCompletion>,
 73    tools: Arc<ToolWorkingSet>,
 74    tool_use: ToolUseState,
 75}
 76
 77impl Thread {
 78    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self {
 79        Self {
 80            id: ThreadId::new(),
 81            updated_at: Utc::now(),
 82            summary: None,
 83            pending_summary: Task::ready(None),
 84            messages: Vec::new(),
 85            next_message_id: MessageId(0),
 86            context: BTreeMap::default(),
 87            context_by_message: HashMap::default(),
 88            completion_count: 0,
 89            pending_completions: Vec::new(),
 90            tools,
 91            tool_use: ToolUseState::default(),
 92        }
 93    }
 94
 95    pub fn from_saved(
 96        id: ThreadId,
 97        saved: SavedThread,
 98        tools: Arc<ToolWorkingSet>,
 99        _cx: &mut Context<Self>,
100    ) -> Self {
101        let next_message_id = MessageId(saved.messages.len());
102
103        Self {
104            id,
105            updated_at: saved.updated_at,
106            summary: Some(saved.summary),
107            pending_summary: Task::ready(None),
108            messages: saved
109                .messages
110                .into_iter()
111                .map(|message| Message {
112                    id: message.id,
113                    role: message.role,
114                    text: message.text,
115                })
116                .collect(),
117            next_message_id,
118            context: BTreeMap::default(),
119            context_by_message: HashMap::default(),
120            completion_count: 0,
121            pending_completions: Vec::new(),
122            tools,
123            tool_use: ToolUseState::default(),
124        }
125    }
126
127    pub fn id(&self) -> &ThreadId {
128        &self.id
129    }
130
131    pub fn is_empty(&self) -> bool {
132        self.messages.is_empty()
133    }
134
135    pub fn updated_at(&self) -> DateTime<Utc> {
136        self.updated_at
137    }
138
139    pub fn touch_updated_at(&mut self) {
140        self.updated_at = Utc::now();
141    }
142
143    pub fn summary(&self) -> Option<SharedString> {
144        self.summary.clone()
145    }
146
147    pub fn summary_or_default(&self) -> SharedString {
148        const DEFAULT: SharedString = SharedString::new_static("New Thread");
149        self.summary.clone().unwrap_or(DEFAULT)
150    }
151
152    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
153        self.summary = Some(summary.into());
154        cx.emit(ThreadEvent::SummaryChanged);
155    }
156
157    pub fn message(&self, id: MessageId) -> Option<&Message> {
158        self.messages.iter().find(|message| message.id == id)
159    }
160
161    pub fn messages(&self) -> impl Iterator<Item = &Message> {
162        self.messages.iter()
163    }
164
165    pub fn is_streaming(&self) -> bool {
166        !self.pending_completions.is_empty()
167    }
168
169    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
170        &self.tools
171    }
172
173    pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
174        let context = self.context_by_message.get(&id)?;
175        Some(
176            context
177                .into_iter()
178                .filter_map(|context_id| self.context.get(&context_id))
179                .cloned()
180                .collect::<Vec<_>>(),
181        )
182    }
183
184    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
185        self.tool_use.pending_tool_uses()
186    }
187
188    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
189        self.tool_use.tool_uses_for_message(id)
190    }
191
192    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
193        self.tool_use.message_has_tool_results(message_id)
194    }
195
196    pub fn insert_user_message(
197        &mut self,
198        text: impl Into<String>,
199        context: Vec<ContextSnapshot>,
200        cx: &mut Context<Self>,
201    ) {
202        let message_id = self.insert_message(Role::User, text, cx);
203        let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
204        self.context
205            .extend(context.into_iter().map(|context| (context.id, context)));
206        self.context_by_message.insert(message_id, context_ids);
207    }
208
209    pub fn insert_message(
210        &mut self,
211        role: Role,
212        text: impl Into<String>,
213        cx: &mut Context<Self>,
214    ) -> MessageId {
215        let id = self.next_message_id.post_inc();
216        self.messages.push(Message {
217            id,
218            role,
219            text: text.into(),
220        });
221        self.touch_updated_at();
222        cx.emit(ThreadEvent::MessageAdded(id));
223        id
224    }
225
226    /// Returns the representation of this [`Thread`] in a textual form.
227    ///
228    /// This is the representation we use when attaching a thread as context to another thread.
229    pub fn text(&self) -> String {
230        let mut text = String::new();
231
232        for message in &self.messages {
233            text.push_str(match message.role {
234                language_model::Role::User => "User:",
235                language_model::Role::Assistant => "Assistant:",
236                language_model::Role::System => "System:",
237            });
238            text.push('\n');
239
240            text.push_str(&message.text);
241            text.push('\n');
242        }
243
244        text
245    }
246
247    pub fn send_to_model(
248        &mut self,
249        model: Arc<dyn LanguageModel>,
250        request_kind: RequestKind,
251        use_tools: bool,
252        cx: &mut Context<Self>,
253    ) {
254        let mut request = self.to_completion_request(request_kind, cx);
255
256        if use_tools {
257            request.tools = self
258                .tools()
259                .tools(cx)
260                .into_iter()
261                .map(|tool| LanguageModelRequestTool {
262                    name: tool.name(),
263                    description: tool.description(),
264                    input_schema: tool.input_schema(),
265                })
266                .collect();
267        }
268
269        self.stream_completion(request, model, cx);
270    }
271
272    pub fn to_completion_request(
273        &self,
274        request_kind: RequestKind,
275        _cx: &App,
276    ) -> LanguageModelRequest {
277        let mut request = LanguageModelRequest {
278            messages: vec![],
279            tools: Vec::new(),
280            stop: Vec::new(),
281            temperature: None,
282        };
283
284        let mut referenced_context_ids = HashSet::default();
285
286        for message in &self.messages {
287            if let Some(context_ids) = self.context_by_message.get(&message.id) {
288                referenced_context_ids.extend(context_ids);
289            }
290
291            let mut request_message = LanguageModelRequestMessage {
292                role: message.role,
293                content: Vec::new(),
294                cache: false,
295            };
296            match request_kind {
297                RequestKind::Chat => {
298                    self.tool_use
299                        .attach_tool_results(message.id, &mut request_message);
300                }
301                RequestKind::Summarize => {
302                    // We don't care about tool use during summarization.
303                }
304            }
305
306            if !message.text.is_empty() {
307                request_message
308                    .content
309                    .push(MessageContent::Text(message.text.clone()));
310            }
311
312            match request_kind {
313                RequestKind::Chat => {
314                    self.tool_use
315                        .attach_tool_uses(message.id, &mut request_message);
316                }
317                RequestKind::Summarize => {
318                    // We don't care about tool use during summarization.
319                }
320            }
321
322            request.messages.push(request_message);
323        }
324
325        if !referenced_context_ids.is_empty() {
326            let mut context_message = LanguageModelRequestMessage {
327                role: Role::User,
328                content: Vec::new(),
329                cache: false,
330            };
331
332            let referenced_context = referenced_context_ids
333                .into_iter()
334                .filter_map(|context_id| self.context.get(context_id))
335                .cloned();
336            attach_context_to_message(&mut context_message, referenced_context);
337
338            request.messages.push(context_message);
339        }
340
341        request
342    }
343
344    pub fn stream_completion(
345        &mut self,
346        request: LanguageModelRequest,
347        model: Arc<dyn LanguageModel>,
348        cx: &mut Context<Self>,
349    ) {
350        let pending_completion_id = post_inc(&mut self.completion_count);
351
352        let task = cx.spawn(|thread, mut cx| async move {
353            let stream = model.stream_completion(request, &cx);
354            let stream_completion = async {
355                let mut events = stream.await?;
356                let mut stop_reason = StopReason::EndTurn;
357
358                while let Some(event) = events.next().await {
359                    let event = event?;
360
361                    thread.update(&mut cx, |thread, cx| {
362                        match event {
363                            LanguageModelCompletionEvent::StartMessage { .. } => {
364                                thread.insert_message(Role::Assistant, String::new(), cx);
365                            }
366                            LanguageModelCompletionEvent::Stop(reason) => {
367                                stop_reason = reason;
368                            }
369                            LanguageModelCompletionEvent::Text(chunk) => {
370                                if let Some(last_message) = thread.messages.last_mut() {
371                                    if last_message.role == Role::Assistant {
372                                        last_message.text.push_str(&chunk);
373                                        cx.emit(ThreadEvent::StreamedAssistantText(
374                                            last_message.id,
375                                            chunk,
376                                        ));
377                                    } else {
378                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
379                                        // of a new Assistant response.
380                                        //
381                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
382                                        // will result in duplicating the text of the chunk in the rendered Markdown.
383                                        thread.insert_message(Role::Assistant, chunk, cx);
384                                    }
385                                }
386                            }
387                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
388                                if let Some(last_assistant_message) = thread
389                                    .messages
390                                    .iter()
391                                    .rfind(|message| message.role == Role::Assistant)
392                                {
393                                    thread
394                                        .tool_use
395                                        .request_tool_use(last_assistant_message.id, tool_use);
396                                }
397                            }
398                        }
399
400                        thread.touch_updated_at();
401                        cx.emit(ThreadEvent::StreamedCompletion);
402                        cx.notify();
403                    })?;
404
405                    smol::future::yield_now().await;
406                }
407
408                thread.update(&mut cx, |thread, cx| {
409                    thread
410                        .pending_completions
411                        .retain(|completion| completion.id != pending_completion_id);
412
413                    if thread.summary.is_none() && thread.messages.len() >= 2 {
414                        thread.summarize(cx);
415                    }
416                })?;
417
418                anyhow::Ok(stop_reason)
419            };
420
421            let result = stream_completion.await;
422
423            thread
424                .update(&mut cx, |thread, cx| match result.as_ref() {
425                    Ok(stop_reason) => match stop_reason {
426                        StopReason::ToolUse => {
427                            cx.emit(ThreadEvent::UsePendingTools);
428                        }
429                        StopReason::EndTurn => {}
430                        StopReason::MaxTokens => {}
431                    },
432                    Err(error) => {
433                        if error.is::<PaymentRequiredError>() {
434                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
435                        } else if error.is::<MaxMonthlySpendReachedError>() {
436                            cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
437                        } else {
438                            let error_message = error
439                                .chain()
440                                .map(|err| err.to_string())
441                                .collect::<Vec<_>>()
442                                .join("\n");
443                            cx.emit(ThreadEvent::ShowError(ThreadError::Message(
444                                SharedString::from(error_message.clone()),
445                            )));
446                        }
447
448                        thread.cancel_last_completion();
449                    }
450                })
451                .ok();
452        });
453
454        self.pending_completions.push(PendingCompletion {
455            id: pending_completion_id,
456            _task: task,
457        });
458    }
459
460    pub fn summarize(&mut self, cx: &mut Context<Self>) {
461        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
462            return;
463        };
464        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
465            return;
466        };
467
468        if !provider.is_authenticated(cx) {
469            return;
470        }
471
472        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
473        request.messages.push(LanguageModelRequestMessage {
474            role: Role::User,
475            content: vec![
476                "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
477                    .into(),
478            ],
479            cache: false,
480        });
481
482        self.pending_summary = cx.spawn(|this, mut cx| {
483            async move {
484                let stream = model.stream_completion_text(request, &cx);
485                let mut messages = stream.await?;
486
487                let mut new_summary = String::new();
488                while let Some(message) = messages.stream.next().await {
489                    let text = message?;
490                    let mut lines = text.lines();
491                    new_summary.extend(lines.next());
492
493                    // Stop if the LLM generated multiple lines.
494                    if lines.next().is_some() {
495                        break;
496                    }
497                }
498
499                this.update(&mut cx, |this, cx| {
500                    if !new_summary.is_empty() {
501                        this.summary = Some(new_summary.into());
502                    }
503
504                    cx.emit(ThreadEvent::SummaryChanged);
505                })?;
506
507                anyhow::Ok(())
508            }
509            .log_err()
510        });
511    }
512
513    pub fn insert_tool_output(
514        &mut self,
515        tool_use_id: LanguageModelToolUseId,
516        output: Task<Result<String>>,
517        cx: &mut Context<Self>,
518    ) {
519        let insert_output_task = cx.spawn(|thread, mut cx| {
520            let tool_use_id = tool_use_id.clone();
521            async move {
522                let output = output.await;
523                thread
524                    .update(&mut cx, |thread, cx| {
525                        thread
526                            .tool_use
527                            .insert_tool_output(tool_use_id.clone(), output);
528
529                        cx.emit(ThreadEvent::ToolFinished { tool_use_id });
530                    })
531                    .ok();
532            }
533        });
534
535        self.tool_use
536            .run_pending_tool(tool_use_id, insert_output_task);
537    }
538
539    /// Cancels the last pending completion, if there are any pending.
540    ///
541    /// Returns whether a completion was canceled.
542    pub fn cancel_last_completion(&mut self) -> bool {
543        if let Some(_last_completion) = self.pending_completions.pop() {
544            true
545        } else {
546            false
547        }
548    }
549}
550
551#[derive(Debug, Clone)]
552pub enum ThreadError {
553    PaymentRequired,
554    MaxMonthlySpendReached,
555    Message(SharedString),
556}
557
558#[derive(Debug, Clone)]
559pub enum ThreadEvent {
560    ShowError(ThreadError),
561    StreamedCompletion,
562    StreamedAssistantText(MessageId, String),
563    MessageAdded(MessageId),
564    SummaryChanged,
565    UsePendingTools,
566    ToolFinished {
567        #[allow(unused)]
568        tool_use_id: LanguageModelToolUseId,
569    },
570}
571
572impl EventEmitter<ThreadEvent> for Thread {}
573
574struct PendingCompletion {
575    id: usize,
576    _task: Task<()>,
577}