thread.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use chrono::{DateTime, Utc};
  6use collections::{BTreeMap, HashMap, HashSet};
  7use futures::StreamExt as _;
  8use gpui::{App, Context, Entity, EventEmitter, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest,
 11    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
 12    LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError,
 13    Role, StopReason,
 14};
 15use project::Project;
 16use serde::{Deserialize, Serialize};
 17use util::{post_inc, TryFutureExt as _};
 18use uuid::Uuid;
 19
 20use crate::context::{attach_context_to_message, ContextId, ContextSnapshot};
 21use crate::thread_store::SavedThread;
 22use crate::tool_use::{PendingToolUse, ToolUse, ToolUseState};
 23
 /// The purpose of a completion request built from a [`Thread`].
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A regular chat turn; tool results and tool uses are included in the request.
    Chat,
    /// Used when summarizing a thread; tool results and tool uses are left out.
    Summarize,
}
 30
 31#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)]
 32pub struct ThreadId(Arc<str>);
 33
 34impl ThreadId {
 35    pub fn new() -> Self {
 36        Self(Uuid::new_v4().to_string().into())
 37    }
 38}
 39
 40impl std::fmt::Display for ThreadId {
 41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
 42        write!(f, "{}", self.0)
 43    }
 44}
 45
 46#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 47pub struct MessageId(pub(crate) usize);
 48
 49impl MessageId {
 50    fn post_inc(&mut self) -> Self {
 51        Self(post_inc(&mut self.0))
 52    }
 53}
 54
 55/// A message in a [`Thread`].
 56#[derive(Debug, Clone)]
 57pub struct Message {
 58    pub id: MessageId,
 59    pub role: Role,
 60    pub text: String,
 61}
 62
 63/// A thread of conversation with the LLM.
 64pub struct Thread {
 65    id: ThreadId,
 66    updated_at: DateTime<Utc>,
 67    summary: Option<SharedString>,
 68    pending_summary: Task<Option<()>>,
 69    messages: Vec<Message>,
 70    next_message_id: MessageId,
 71    context: BTreeMap<ContextId, ContextSnapshot>,
 72    context_by_message: HashMap<MessageId, Vec<ContextId>>,
 73    completion_count: usize,
 74    pending_completions: Vec<PendingCompletion>,
 75    project: Entity<Project>,
 76    tools: Arc<ToolWorkingSet>,
 77    tool_use: ToolUseState,
 78}
 79
 80impl Thread {
 81    pub fn new(
 82        project: Entity<Project>,
 83        tools: Arc<ToolWorkingSet>,
 84        _cx: &mut Context<Self>,
 85    ) -> Self {
 86        Self {
 87            id: ThreadId::new(),
 88            updated_at: Utc::now(),
 89            summary: None,
 90            pending_summary: Task::ready(None),
 91            messages: Vec::new(),
 92            next_message_id: MessageId(0),
 93            context: BTreeMap::default(),
 94            context_by_message: HashMap::default(),
 95            completion_count: 0,
 96            pending_completions: Vec::new(),
 97            project,
 98            tools,
 99            tool_use: ToolUseState::new(),
100        }
101    }
102
103    pub fn from_saved(
104        id: ThreadId,
105        saved: SavedThread,
106        project: Entity<Project>,
107        tools: Arc<ToolWorkingSet>,
108        _cx: &mut Context<Self>,
109    ) -> Self {
110        let next_message_id = MessageId(
111            saved
112                .messages
113                .last()
114                .map(|message| message.id.0 + 1)
115                .unwrap_or(0),
116        );
117        let tool_use = ToolUseState::from_saved_messages(&saved.messages);
118
119        Self {
120            id,
121            updated_at: saved.updated_at,
122            summary: Some(saved.summary),
123            pending_summary: Task::ready(None),
124            messages: saved
125                .messages
126                .into_iter()
127                .map(|message| Message {
128                    id: message.id,
129                    role: message.role,
130                    text: message.text,
131                })
132                .collect(),
133            next_message_id,
134            context: BTreeMap::default(),
135            context_by_message: HashMap::default(),
136            completion_count: 0,
137            pending_completions: Vec::new(),
138            project,
139            tools,
140            tool_use,
141        }
142    }
143
144    pub fn id(&self) -> &ThreadId {
145        &self.id
146    }
147
148    pub fn is_empty(&self) -> bool {
149        self.messages.is_empty()
150    }
151
152    pub fn updated_at(&self) -> DateTime<Utc> {
153        self.updated_at
154    }
155
156    pub fn touch_updated_at(&mut self) {
157        self.updated_at = Utc::now();
158    }
159
160    pub fn summary(&self) -> Option<SharedString> {
161        self.summary.clone()
162    }
163
164    pub fn summary_or_default(&self) -> SharedString {
165        const DEFAULT: SharedString = SharedString::new_static("New Thread");
166        self.summary.clone().unwrap_or(DEFAULT)
167    }
168
169    pub fn set_summary(&mut self, summary: impl Into<SharedString>, cx: &mut Context<Self>) {
170        self.summary = Some(summary.into());
171        cx.emit(ThreadEvent::SummaryChanged);
172    }
173
174    pub fn message(&self, id: MessageId) -> Option<&Message> {
175        self.messages.iter().find(|message| message.id == id)
176    }
177
178    pub fn messages(&self) -> impl Iterator<Item = &Message> {
179        self.messages.iter()
180    }
181
182    pub fn is_streaming(&self) -> bool {
183        !self.pending_completions.is_empty()
184    }
185
186    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
187        &self.tools
188    }
189
190    pub fn context_for_message(&self, id: MessageId) -> Option<Vec<ContextSnapshot>> {
191        let context = self.context_by_message.get(&id)?;
192        Some(
193            context
194                .into_iter()
195                .filter_map(|context_id| self.context.get(&context_id))
196                .cloned()
197                .collect::<Vec<_>>(),
198        )
199    }
200
201    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
202        self.tool_use.pending_tool_uses()
203    }
204
205    /// Returns whether all of the tool uses have finished running.
206    pub fn all_tools_finished(&self) -> bool {
207        // If the only pending tool uses left are the ones with errors, then that means that we've finished running all
208        // of the pending tools.
209        self.pending_tool_uses()
210            .into_iter()
211            .all(|tool_use| tool_use.status.is_error())
212    }
213
214    pub fn tool_uses_for_message(&self, id: MessageId) -> Vec<ToolUse> {
215        self.tool_use.tool_uses_for_message(id)
216    }
217
218    pub fn tool_results_for_message(&self, id: MessageId) -> Vec<&LanguageModelToolResult> {
219        self.tool_use.tool_results_for_message(id)
220    }
221
222    pub fn message_has_tool_results(&self, message_id: MessageId) -> bool {
223        self.tool_use.message_has_tool_results(message_id)
224    }
225
226    pub fn insert_user_message(
227        &mut self,
228        text: impl Into<String>,
229        context: Vec<ContextSnapshot>,
230        cx: &mut Context<Self>,
231    ) {
232        let message_id = self.insert_message(Role::User, text, cx);
233        let context_ids = context.iter().map(|context| context.id).collect::<Vec<_>>();
234        self.context
235            .extend(context.into_iter().map(|context| (context.id, context)));
236        self.context_by_message.insert(message_id, context_ids);
237    }
238
239    pub fn insert_message(
240        &mut self,
241        role: Role,
242        text: impl Into<String>,
243        cx: &mut Context<Self>,
244    ) -> MessageId {
245        let id = self.next_message_id.post_inc();
246        self.messages.push(Message {
247            id,
248            role,
249            text: text.into(),
250        });
251        self.touch_updated_at();
252        cx.emit(ThreadEvent::MessageAdded(id));
253        id
254    }
255
256    pub fn edit_message(
257        &mut self,
258        id: MessageId,
259        new_role: Role,
260        new_text: String,
261        cx: &mut Context<Self>,
262    ) -> bool {
263        let Some(message) = self.messages.iter_mut().find(|message| message.id == id) else {
264            return false;
265        };
266        message.role = new_role;
267        message.text = new_text;
268        self.touch_updated_at();
269        cx.emit(ThreadEvent::MessageEdited(id));
270        true
271    }
272
273    pub fn delete_message(&mut self, id: MessageId, cx: &mut Context<Self>) -> bool {
274        let Some(index) = self.messages.iter().position(|message| message.id == id) else {
275            return false;
276        };
277        self.messages.remove(index);
278        self.context_by_message.remove(&id);
279        self.touch_updated_at();
280        cx.emit(ThreadEvent::MessageDeleted(id));
281        true
282    }
283
284    /// Returns the representation of this [`Thread`] in a textual form.
285    ///
286    /// This is the representation we use when attaching a thread as context to another thread.
287    pub fn text(&self) -> String {
288        let mut text = String::new();
289
290        for message in &self.messages {
291            text.push_str(match message.role {
292                language_model::Role::User => "User:",
293                language_model::Role::Assistant => "Assistant:",
294                language_model::Role::System => "System:",
295            });
296            text.push('\n');
297
298            text.push_str(&message.text);
299            text.push('\n');
300        }
301
302        text
303    }
304
305    pub fn send_to_model(
306        &mut self,
307        model: Arc<dyn LanguageModel>,
308        request_kind: RequestKind,
309        use_tools: bool,
310        cx: &mut Context<Self>,
311    ) {
312        let mut request = self.to_completion_request(request_kind, cx);
313
314        if use_tools {
315            request.tools = self
316                .tools()
317                .tools(cx)
318                .into_iter()
319                .map(|tool| LanguageModelRequestTool {
320                    name: tool.name(),
321                    description: tool.description(),
322                    input_schema: tool.input_schema(),
323                })
324                .collect();
325        }
326
327        self.stream_completion(request, model, cx);
328    }
329
330    pub fn to_completion_request(
331        &self,
332        request_kind: RequestKind,
333        _cx: &App,
334    ) -> LanguageModelRequest {
335        let mut request = LanguageModelRequest {
336            messages: vec![],
337            tools: Vec::new(),
338            stop: Vec::new(),
339            temperature: None,
340        };
341
342        let mut referenced_context_ids = HashSet::default();
343
344        for message in &self.messages {
345            if let Some(context_ids) = self.context_by_message.get(&message.id) {
346                referenced_context_ids.extend(context_ids);
347            }
348
349            let mut request_message = LanguageModelRequestMessage {
350                role: message.role,
351                content: Vec::new(),
352                cache: false,
353            };
354            match request_kind {
355                RequestKind::Chat => {
356                    self.tool_use
357                        .attach_tool_results(message.id, &mut request_message);
358                }
359                RequestKind::Summarize => {
360                    // We don't care about tool use during summarization.
361                }
362            }
363
364            if !message.text.is_empty() {
365                request_message
366                    .content
367                    .push(MessageContent::Text(message.text.clone()));
368            }
369
370            match request_kind {
371                RequestKind::Chat => {
372                    self.tool_use
373                        .attach_tool_uses(message.id, &mut request_message);
374                }
375                RequestKind::Summarize => {
376                    // We don't care about tool use during summarization.
377                }
378            }
379
380            request.messages.push(request_message);
381        }
382
383        if !referenced_context_ids.is_empty() {
384            let mut context_message = LanguageModelRequestMessage {
385                role: Role::User,
386                content: Vec::new(),
387                cache: false,
388            };
389
390            let referenced_context = referenced_context_ids
391                .into_iter()
392                .filter_map(|context_id| self.context.get(context_id))
393                .cloned();
394            attach_context_to_message(&mut context_message, referenced_context);
395
396            request.messages.push(context_message);
397        }
398
399        request
400    }
401
402    pub fn stream_completion(
403        &mut self,
404        request: LanguageModelRequest,
405        model: Arc<dyn LanguageModel>,
406        cx: &mut Context<Self>,
407    ) {
408        let pending_completion_id = post_inc(&mut self.completion_count);
409
410        let task = cx.spawn(|thread, mut cx| async move {
411            let stream = model.stream_completion(request, &cx);
412            let stream_completion = async {
413                let mut events = stream.await?;
414                let mut stop_reason = StopReason::EndTurn;
415
416                while let Some(event) = events.next().await {
417                    let event = event?;
418
419                    thread.update(&mut cx, |thread, cx| {
420                        match event {
421                            LanguageModelCompletionEvent::StartMessage { .. } => {
422                                thread.insert_message(Role::Assistant, String::new(), cx);
423                            }
424                            LanguageModelCompletionEvent::Stop(reason) => {
425                                stop_reason = reason;
426                            }
427                            LanguageModelCompletionEvent::Text(chunk) => {
428                                if let Some(last_message) = thread.messages.last_mut() {
429                                    if last_message.role == Role::Assistant {
430                                        last_message.text.push_str(&chunk);
431                                        cx.emit(ThreadEvent::StreamedAssistantText(
432                                            last_message.id,
433                                            chunk,
434                                        ));
435                                    } else {
436                                        // If we won't have an Assistant message yet, assume this chunk marks the beginning
437                                        // of a new Assistant response.
438                                        //
439                                        // Importantly: We do *not* want to emit a `StreamedAssistantText` event here, as it
440                                        // will result in duplicating the text of the chunk in the rendered Markdown.
441                                        thread.insert_message(Role::Assistant, chunk, cx);
442                                    }
443                                }
444                            }
445                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
446                                if let Some(last_assistant_message) = thread
447                                    .messages
448                                    .iter()
449                                    .rfind(|message| message.role == Role::Assistant)
450                                {
451                                    thread
452                                        .tool_use
453                                        .request_tool_use(last_assistant_message.id, tool_use);
454                                }
455                            }
456                        }
457
458                        thread.touch_updated_at();
459                        cx.emit(ThreadEvent::StreamedCompletion);
460                        cx.notify();
461                    })?;
462
463                    smol::future::yield_now().await;
464                }
465
466                thread.update(&mut cx, |thread, cx| {
467                    thread
468                        .pending_completions
469                        .retain(|completion| completion.id != pending_completion_id);
470
471                    if thread.summary.is_none() && thread.messages.len() >= 2 {
472                        thread.summarize(cx);
473                    }
474                })?;
475
476                anyhow::Ok(stop_reason)
477            };
478
479            let result = stream_completion.await;
480
481            thread
482                .update(&mut cx, |thread, cx| match result.as_ref() {
483                    Ok(stop_reason) => match stop_reason {
484                        StopReason::ToolUse => {
485                            cx.emit(ThreadEvent::UsePendingTools);
486                        }
487                        StopReason::EndTurn => {}
488                        StopReason::MaxTokens => {}
489                    },
490                    Err(error) => {
491                        if error.is::<PaymentRequiredError>() {
492                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
493                        } else if error.is::<MaxMonthlySpendReachedError>() {
494                            cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
495                        } else {
496                            let error_message = error
497                                .chain()
498                                .map(|err| err.to_string())
499                                .collect::<Vec<_>>()
500                                .join("\n");
501                            cx.emit(ThreadEvent::ShowError(ThreadError::Message(
502                                SharedString::from(error_message.clone()),
503                            )));
504                        }
505
506                        thread.cancel_last_completion();
507                    }
508                })
509                .ok();
510        });
511
512        self.pending_completions.push(PendingCompletion {
513            id: pending_completion_id,
514            _task: task,
515        });
516    }
517
518    pub fn summarize(&mut self, cx: &mut Context<Self>) {
519        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
520            return;
521        };
522        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
523            return;
524        };
525
526        if !provider.is_authenticated(cx) {
527            return;
528        }
529
530        let mut request = self.to_completion_request(RequestKind::Summarize, cx);
531        request.messages.push(LanguageModelRequestMessage {
532            role: Role::User,
533            content: vec![
534                "Generate a concise 3-7 word title for this conversation, omitting punctuation. Go straight to the title, without any preamble and prefix like `Here's a concise suggestion:...` or `Title:`"
535                    .into(),
536            ],
537            cache: false,
538        });
539
540        self.pending_summary = cx.spawn(|this, mut cx| {
541            async move {
542                let stream = model.stream_completion_text(request, &cx);
543                let mut messages = stream.await?;
544
545                let mut new_summary = String::new();
546                while let Some(message) = messages.stream.next().await {
547                    let text = message?;
548                    let mut lines = text.lines();
549                    new_summary.extend(lines.next());
550
551                    // Stop if the LLM generated multiple lines.
552                    if lines.next().is_some() {
553                        break;
554                    }
555                }
556
557                this.update(&mut cx, |this, cx| {
558                    if !new_summary.is_empty() {
559                        this.summary = Some(new_summary.into());
560                    }
561
562                    cx.emit(ThreadEvent::SummaryChanged);
563                })?;
564
565                anyhow::Ok(())
566            }
567            .log_err()
568        });
569    }
570
571    pub fn use_pending_tools(&mut self, cx: &mut Context<Self>) {
572        let pending_tool_uses = self
573            .pending_tool_uses()
574            .into_iter()
575            .filter(|tool_use| tool_use.status.is_idle())
576            .cloned()
577            .collect::<Vec<_>>();
578
579        for tool_use in pending_tool_uses {
580            if let Some(tool) = self.tools.tool(&tool_use.name, cx) {
581                let task = tool.run(tool_use.input, self.project.clone(), cx);
582
583                self.insert_tool_output(tool_use.id.clone(), task, cx);
584            }
585        }
586    }
587
588    pub fn insert_tool_output(
589        &mut self,
590        tool_use_id: LanguageModelToolUseId,
591        output: Task<Result<String>>,
592        cx: &mut Context<Self>,
593    ) {
594        let insert_output_task = cx.spawn(|thread, mut cx| {
595            let tool_use_id = tool_use_id.clone();
596            async move {
597                let output = output.await;
598                thread
599                    .update(&mut cx, |thread, cx| {
600                        thread
601                            .tool_use
602                            .insert_tool_output(tool_use_id.clone(), output);
603
604                        cx.emit(ThreadEvent::ToolFinished { tool_use_id });
605                    })
606                    .ok();
607            }
608        });
609
610        self.tool_use
611            .run_pending_tool(tool_use_id, insert_output_task);
612    }
613
614    pub fn send_tool_results_to_model(
615        &mut self,
616        model: Arc<dyn LanguageModel>,
617        cx: &mut Context<Self>,
618    ) {
619        // Insert a user message to contain the tool results.
620        self.insert_user_message(
621            // TODO: Sending up a user message without any content results in the model sending back
622            // responses that also don't have any content. We currently don't handle this case well,
623            // so for now we provide some text to keep the model on track.
624            "Here are the tool results.",
625            Vec::new(),
626            cx,
627        );
628        self.send_to_model(model, RequestKind::Chat, true, cx);
629    }
630
631    /// Cancels the last pending completion, if there are any pending.
632    ///
633    /// Returns whether a completion was canceled.
634    pub fn cancel_last_completion(&mut self) -> bool {
635        if let Some(_last_completion) = self.pending_completions.pop() {
636            true
637        } else {
638            false
639        }
640    }
641}
642
643#[derive(Debug, Clone)]
644pub enum ThreadError {
645    PaymentRequired,
646    MaxMonthlySpendReached,
647    Message(SharedString),
648}
649
650#[derive(Debug, Clone)]
651pub enum ThreadEvent {
652    ShowError(ThreadError),
653    StreamedCompletion,
654    StreamedAssistantText(MessageId, String),
655    MessageAdded(MessageId),
656    MessageEdited(MessageId),
657    MessageDeleted(MessageId),
658    SummaryChanged,
659    UsePendingTools,
660    ToolFinished {
661        #[allow(unused)]
662        tool_use_id: LanguageModelToolUseId,
663    },
664}
665
666impl EventEmitter<ThreadEvent> for Thread {}
667
668struct PendingCompletion {
669    id: usize,
670    _task: Task<()>,
671}