thread.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use chrono::{DateTime, Utc};
  6use collections::{BTreeMap, HashMap, HashSet};
  7use futures::StreamExt as _;
  8use gpui::{App, Context, EventEmitter, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
 11    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
 12    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
 13    Role, StopReason,
 14};
 15use serde::{Deserialize, Serialize};
 16use util::{post_inc, TryFutureExt as _};
 17use uuid::Uuid;
 18
 19use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
 20use crate::thread_store::SavedThread;
 21use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
 22
 23#[derive(Debug, Clone, Copy)]
 24pub enum RequestKind {
 25    Chat,
 26    /// Used when summarizing a thread.
 27    Summarize,
 28}
 29
 30#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
 31pub struct ThreadId(Arc<str>);
 32
 33impl ThreadId {
 34    pub fn new() -> Self {
 35        Self(Uuid::new_v4().to_string().into())
 36    }
 37}
 38
 39impl std::fmt::Display for ThreadId {
 40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 41        write!(f, "{}", self.0)
 42    }
 43}
 44
 45#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 46pub struct MessageId(pub(crate) usize);
 47
 48impl MessageId {
 49    fn post_inc(&mut self) -> Self {
 50        Self(post_inc(&mut self.0))
 51    }
 52}
 53
 54/// A message in a [`Thread`].
 55#[derive(Debug, Clone)]
 56pub struct Message {
 57    pub id: MessageId,
 58    pub role: Role,
 59    pub text: String,
 60}
 61
 62/// A thread of conversation with the LLM.
 63pub struct Thread {
 64    id: ThreadId,
 65    updated_at: DateTime<Utc>,
 66    summary: Option<SharedString>,
 67    pending_summary: Task<Option<()>>,
 68    messages: Vec<Message>,
 69    next_message_id: MessageId,
 70    context: BTreeMap<ContextId, ContextSnapshot>,
 71    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 72    completion_count: usize,
 73    pending_completions: Vec<PendingCompletion>,
 74    tools: Arc<ToolWorkingSet>,
 75    tool_use: ToolUseState,
 76}
 77
 78impl Thread {
 79    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut Context<Self>) -> Self {
 80        Self {
 81            id: ThreadId::new(),
 82            updated_at: Utc::now(),
 83            summary: None,
 84            pending_summary: Task::ready(None),
 85            messages: Vec::new(),
 86            next_message_id: MessageId(0),
 87            context: BTreeMap::default(),
 88            context_by_message: HashMap::default(),
 89            completion_count: 0,
 90            pending_completions: Vec::new(),
 91            tools,
 92            tool_use: ToolUseState::new(),
 93        }
 94    }
 95
 96    pub fn from_saved(
 97        id: ThreadId,
 98        saved: SavedThread,
 99        tools: Arc<ToolWorkingSet>,
100        _cx: &mut Context<Self>,
101    ) -> Self {
102        let next_message_id = MessageId(saved.messages.len());
103        let tool_use = ToolUseState::from_saved_messages(&saved.messages);
104
105        Self {
106            id,
107            updated_at: saved.updated_at,
108            summary: Some(saved.summary),
109            pending_summary: Task::ready(None),
110            messages: saved
111                .messages
112                .into_iter()
113                .map(|message| Message {
114                    id: message.id,
115                    role: message.role,
116                    text: message.text,
117                })
118                .collect(),
119            next_message_id,
120            context: BTreeMap::default(),
121            context_by_message: HashMap::default(),
122            completion_count: 0,
123            pending_completions: Vec::new(),
124            tools,
125            tool_use,
126        }
127    }
128
129    pub fn id(&self) -> &ThreadId {
130        &self.id
131    }
132
133    pub fn is_empty(&self) -> bool {
134        self.messages.is_empty()
135    }
136
137    pub fn updated_at(&self) -> DateTime<Utc> {
138        self.updated_at
139    }
140
141    pub fn touch_updated_at(&mut self) {
142        self.updated_at = Utc::now();
143    }
144
145    pub fn summary(&self) -> Option<SharedString> {
146        self.summary.clone()
147    }
148
149    pub fn summary_or_default(&self) -> SharedString {
150        const DEFAULT: SharedString = SharedString::new_static("New Thread");
151        self.summary.clone().unwrap_or(DEFAULT)
152    }
153
154    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
155        self.summary = Some(summary.into());
156        cx.emit(ThreadEvent::SummaryChanged);
157    }
158
159    pub fn message(&self, id: MessageId) -> Option<&Message> {
160        self.messages.iter().find(|message| message.id == id)
161    }
162
163    pub fn messages(&self) -> impl Iterator<Item = &Message> {
164        self.messages.iter()
165    }
166
167    pub fn is_streaming(&self) -> bool {
168        !self.pending_completions.is_empty()
169    }
170
171    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
172        &self.tools
173    }
174
175    pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
176        let context = self.context_by_message.get(&id)?;
177        Some(
178            context
179                .into_iter()
180                .filter_map(|context_id| self.context.get(&context_id))
181                .cloned()
182                .collect::<Vec<_>>(),
183        )
184    }
185
186    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
187        self.tool_use.pending_tool_uses()
188    }
189
190    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
191        self.tool_use.tool_uses_for_message(id)
192    }
193
194    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
195        self.tool_use.tool_results_for_message(id)
196    }
197
198    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
199        self.tool_use.message_has_tool_results(message_id)
200    }
201
202    pub fn insert_user_message(
203        &mut self,
204        text: impl Into<String>,
205        context: Vec<ContextSnapshot>,
206        cx: &mut Context<Self>,
207    ) {
208        let message_id = self.insert_message(Role::User, text, cx);
209        let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
210        self.context
211            .extend(context.into_iter().map(|context| (context.id, context)));
212        self.context_by_message.insert(message_id, context_ids);
213    }
214
215    pub fn insert_message(
216        &mut self,
217        role: Role,
218        text: impl Into<String>,
219        cx: &mut Context<Self>,
220    ) -> MessageId {
221        let id = self.next_message_id.post_inc();
222        self.messages.push(Message {
223            id,
224            role,
225            text: text.into(),
226        });
227        self.touch_updated_at();
228        cx.emit(ThreadEvent::MessageAdded(id));
229        id
230    }
231
232    /// Returns the representation of this [`Thread`] in a textual form.
233    ///
234    /// This is the representation we use when attaching a thread as context to another thread.
235    pub fn text(&self) -> String {
236        let mut text = String::new();
237
238        for message in &self.messages {
239            text.push_str(match message.role {
240                language_model::Role::User => "User:",
241                language_model::Role::Assistant => "Assistant:",
242                language_model::Role::System => "System:",
243            });
244            text.push('\n');
245
246            text.push_str(&message.text);
247            text.push('\n');
248        }
249
250        text
251    }
252
253    pub fn send_to_model(
254        &mut self,
255        model: Arc<dyn LanguageModel>,
256        request_kind: RequestKind,
257        use_tools: bool,
258        cx: &mut Context<Self>,
259    ) {
260        let mut request = self.to_completion_request(request_kind, cx);
261
262        if use_tools {
263            request.tools = self
264                .tools()
265                .tools(cx)
266                .into_iter()
267                .map(|tool| LanguageModelRequestTool {
268                    name: tool.name(),
269                    description: tool.description(),
270                    input_schema: tool.input_schema(),
271                })
272                .collect();
273        }
274
275        self.stream_completion(request, model, cx);
276    }
277
278    pub fn to_completion_request(
279        &self,
280        request_kind: RequestKind,
281        _cx: &App,
282    ) -> LanguageModelRequest {
283        let mut request = LanguageModelRequest {
284            messages: vec![],
285            tools: Vec::new(),
286            stop: Vec::new(),
287            temperature: None,
288        };
289
290        let mut referenced_context_ids = HashSet::default();
291
292        for message in &self.messages {
293            if let Some(context_ids) = self.context_by_message.get(&message.id) {
294                referenced_context_ids.extend(context_ids);
295            }
296
297            let mut request_message = LanguageModelRequestMessage {
298                role: message.role,
299                content: Vec::new(),
300                cache: false,
301            };
302            match request_kind {
303                RequestKind::Chat => {
304                    self.tool_use
305                        .attach_tool_results(message.id, &mut request_message);
306                }
307                RequestKind::Summarize => {
308                    // We don't care about tool use during summarization.
309                }
310            }
311
312            if !message.text.is_empty() {
313                request_message
314                    .content
315                    .push(MessageContent::Text(message.text.clone()));
316            }
317
318            match request_kind {
319                RequestKind::Chat => {
320                    self.tool_use
321                        .attach_tool_uses(message.id, &mut request_message);
322                }
323                RequestKind::Summarize => {
324                    // We don't care about tool use during summarization.
325                }
326            }
327
328            request.messages.push(request_message);
329        }
330
331        if !referenced_context_ids.is_empty() {
332            let mut context_message = LanguageModelRequestMessage {
333                role: Role::User,
334                content: Vec::new(),
335                cache: false,
336            };
337
338            let referenced_context = referenced_context_ids
339                .into_iter()
340                .filter_map(|context_id| self.context.get(context_id))
341                .cloned();
342            attach_context_to_message(&mut context_message, referenced_context);
343
344            request.messages.push(context_message);
345        }
346
347        request
348    }
349
350    pub fn stream_completion(
351        &mut self,
352        request: LanguageModelRequest,
353        model: Arc<dyn LanguageModel>,
354        cx: &mut Context<Self>,
355    ) {
356        let pending_completion_id = post_inc(&mut self.completion_count);
357
358        let task = cx.spawn(|thread, mut cx| async move {
359            let stream = model.stream_completion(request, &cx);
360            let stream_completion = async {
361                let mut events = stream.await?;
362                let mut stop_reason = StopReason::EndTurn;
363
364                while let Some(event) = events.next().await {
365                    let event = event?;
366
367                    thread.update(&mut cx, |thread, cx| {
368                        match event {
369                            LanguageModelCompletionEvent::StartMessage { .. } => {
370                                thread.insert_message(Role::Assistant, String::new(), cx);
371                            }
372                            LanguageModelCompletionEvent::Stop(reason) => {
373                                stop_reason = reason;
374                            }
375                            LanguageModelCompletionEvent::Text(chunk) => {
376                                if let Some(last_message) = thread.messages.last_mut() {
377                                    if last_message.role == Role::Assistant {
378                                        last_message.text.push_str(&chunk);
379                                        cx.emit(ThreadEvent::StreamedAssistantText(
380                                            last_message.id,
381                                            chunk,
382                                        ));
383                                    } else {
384                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
385                                        // of a new Assistant response.
386                                        //
387                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
388                                        // will result in duplicating the text of the chunk in the rendered Markdown.
389                                        thread.insert_message(Role::Assistant, chunk, cx);
390                                    }
391                                }
392                            }
393                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
394                                if let Some(last_assistant_message) = thread
395                                    .messages
396                                    .iter()
397                                    .rfind(|message| message.role == Role::Assistant)
398                                {
399                                    thread
400                                        .tool_use
401                                        .request_tool_use(last_assistant_message.id, tool_use);
402                                }
403                            }
404                        }
405
406                        thread.touch_updated_at();
407                        cx.emit(ThreadEvent::StreamedCompletion);
408                        cx.notify();
409                    })?;
410
411                    smol::future::yield_now().await;
412                }
413
414                thread.update(&mut cx, |thread, cx| {
415                    thread
416                        .pending_completions
417                        .retain(|completion| completion.id != pending_completion_id);
418
419                    if thread.summary.is_none() && thread.messages.len() >= 2 {
420                        thread.summarize(cx);
421                    }
422                })?;
423
424                anyhow::Ok(stop_reason)
425            };
426
427            let result = stream_completion.await;
428
429            thread
430                .update(&mut cx, |thread, cx| match result.as_ref() {
431                    Ok(stop_reason) => match stop_reason {
432                        StopReason::ToolUse => {
433                            cx.emit(ThreadEvent::UsePendingTools);
434                        }
435                        StopReason::EndTurn => {}
436                        StopReason::MaxTokens => {}
437                    },
438                    Err(error) => {
439                        if error.is::<PaymentRequiredError>() {
440                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
441                        } else if error.is::<MaxMonthlySpendReachedError>() {
442                            cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
443                        } else {
444                            let error_message = error
445                                .chain()
446                                .map(|err| err.to_string())
447                                .collect::<Vec<_>>()
448                                .join("\n");
449                            cx.emit(ThreadEvent::ShowError(ThreadError::Message(
450                                SharedString::from(error_message.clone()),
451                            )));
452                        }
453
454                        thread.cancel_last_completion();
455                    }
456                })
457                .ok();
458        });
459
460        self.pending_completions.push(PendingCompletion {
461            id: pending_completion_id,
462            _task: task,
463        });
464    }
465
466    pub fn summarize(&mut self, cx: &mut Context<Self>) {
467        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
468            return;
469        };
470        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
471            return;
472        };
473
474        if !provider.is_authenticated(cx) {
475            return;
476        }
477
478        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
479        request.messages.push(LanguageModelRequestMessage {
480            role: Role::User,
481            content: vec![
482                "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
483                    .into(),
484            ],
485            cache: false,
486        });
487
488        self.pending_summary = cx.spawn(|this, mut cx| {
489            async move {
490                let stream = model.stream_completion_text(request, &cx);
491                let mut messages = stream.await?;
492
493                let mut new_summary = String::new();
494                while let Some(message) = messages.stream.next().await {
495                    let text = message?;
496                    let mut lines = text.lines();
497                    new_summary.extend(lines.next());
498
499                    // Stop if the LLM generated multiple lines.
500                    if lines.next().is_some() {
501                        break;
502                    }
503                }
504
505                this.update(&mut cx, |this, cx| {
506                    if !new_summary.is_empty() {
507                        this.summary = Some(new_summary.into());
508                    }
509
510                    cx.emit(ThreadEvent::SummaryChanged);
511                })?;
512
513                anyhow::Ok(())
514            }
515            .log_err()
516        });
517    }
518
519    pub fn insert_tool_output(
520        &mut self,
521        tool_use_id: LanguageModelToolUseId,
522        output: Task<Result<String>>,
523        cx: &mut Context<Self>,
524    ) {
525        let insert_output_task = cx.spawn(|thread, mut cx| {
526            let tool_use_id = tool_use_id.clone();
527            async move {
528                let output = output.await;
529                thread
530                    .update(&mut cx, |thread, cx| {
531                        thread
532                            .tool_use
533                            .insert_tool_output(tool_use_id.clone(), output);
534
535                        cx.emit(ThreadEvent::ToolFinished { tool_use_id });
536                    })
537                    .ok();
538            }
539        });
540
541        self.tool_use
542            .run_pending_tool(tool_use_id, insert_output_task);
543    }
544
545    /// Cancels the last pending completion, if there are any pending.
546    ///
547    /// Returns whether a completion was canceled.
548    pub fn cancel_last_completion(&mut self) -> bool {
549        if let Some(_last_completion) = self.pending_completions.pop() {
550            true
551        } else {
552            false
553        }
554    }
555}
556
557#[derive(Debug, Clone)]
558pub enum ThreadError {
559    PaymentRequired,
560    MaxMonthlySpendReached,
561    Message(SharedString),
562}
563
564#[derive(Debug, Clone)]
565pub enum ThreadEvent {
566    ShowError(ThreadError),
567    StreamedCompletion,
568    StreamedAssistantText(MessageId, String),
569    MessageAdded(MessageId),
570    SummaryChanged,
571    UsePendingTools,
572    ToolFinished {
573        #[allow(unused)]
574        tool_use_id: LanguageModelToolUseId,
575    },
576}
577
578impl EventEmitter<ThreadEvent> for Thread {}
579
580struct PendingCompletion {
581    id: usize,
582    _task: Task<()>,
583}