thread.rs

  1use std::sync::Arc;
  2
  3use futures::StreamExt as _;
  4use gpui::{AppContext, EventEmitter, ModelContext, Task};
  5use language_model::{
  6    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
  7    MessageContent, Role, StopReason,
  8};
  9use util::{post_inc, ResultExt as _};
 10
 11#[derive(Debug, Clone, Copy)]
 12pub enum RequestKind {
 13    Chat,
 14}
 15
 16/// A message in a [`Thread`].
 17pub struct Message {
 18    pub role: Role,
 19    pub text: String,
 20}
 21
 22struct PendingCompletion {
 23    id: usize,
 24    _task: Task<()>,
 25}
 26
 27/// A thread of conversation with the LLM.
 28pub struct Thread {
 29    messages: Vec<Message>,
 30    completion_count: usize,
 31    pending_completions: Vec<PendingCompletion>,
 32}
 33
 34impl Thread {
 35    pub fn new(_cx: &mut ModelContext<Self>) -> Self {
 36        Self {
 37            messages: Vec::new(),
 38            completion_count: 0,
 39            pending_completions: Vec::new(),
 40        }
 41    }
 42
 43    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 44        self.messages.iter()
 45    }
 46
 47    pub fn insert_user_message(&mut self, text: impl Into<String>) {
 48        self.messages.push(Message {
 49            role: Role::User,
 50            text: text.into(),
 51        });
 52    }
 53
 54    pub fn to_completion_request(
 55        &self,
 56        _request_kind: RequestKind,
 57        _cx: &AppContext,
 58    ) -> LanguageModelRequest {
 59        let mut request = LanguageModelRequest {
 60            messages: vec![],
 61            tools: Vec::new(),
 62            stop: Vec::new(),
 63            temperature: None,
 64        };
 65
 66        for message in &self.messages {
 67            let mut request_message = LanguageModelRequestMessage {
 68                role: message.role,
 69                content: Vec::new(),
 70                cache: false,
 71            };
 72
 73            request_message
 74                .content
 75                .push(MessageContent::Text(message.text.clone()));
 76
 77            request.messages.push(request_message);
 78        }
 79
 80        request
 81    }
 82
 83    pub fn stream_completion(
 84        &mut self,
 85        request: LanguageModelRequest,
 86        model: Arc<dyn LanguageModel>,
 87        cx: &mut ModelContext<Self>,
 88    ) {
 89        let pending_completion_id = post_inc(&mut self.completion_count);
 90
 91        let task = cx.spawn(|thread, mut cx| async move {
 92            let stream = model.stream_completion(request, &cx);
 93            let stream_completion = async {
 94                let mut events = stream.await?;
 95                let mut stop_reason = StopReason::EndTurn;
 96
 97                while let Some(event) = events.next().await {
 98                    let event = event?;
 99
100                    thread.update(&mut cx, |thread, cx| {
101                        match event {
102                            LanguageModelCompletionEvent::StartMessage { .. } => {
103                                thread.messages.push(Message {
104                                    role: Role::Assistant,
105                                    text: String::new(),
106                                });
107                            }
108                            LanguageModelCompletionEvent::Stop(reason) => {
109                                stop_reason = reason;
110                            }
111                            LanguageModelCompletionEvent::Text(chunk) => {
112                                if let Some(last_message) = thread.messages.last_mut() {
113                                    if last_message.role == Role::Assistant {
114                                        last_message.text.push_str(&chunk);
115                                    }
116                                }
117                            }
118                            LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
119                        }
120
121                        cx.emit(ThreadEvent::StreamedCompletion);
122                        cx.notify();
123                    })?;
124
125                    smol::future::yield_now().await;
126                }
127
128                thread.update(&mut cx, |thread, _cx| {
129                    thread
130                        .pending_completions
131                        .retain(|completion| completion.id != pending_completion_id);
132                })?;
133
134                anyhow::Ok(stop_reason)
135            };
136
137            let result = stream_completion.await;
138            let _ = result.log_err();
139        });
140
141        self.pending_completions.push(PendingCompletion {
142            id: pending_completion_id,
143            _task: task,
144        });
145    }
146}
147
/// Events emitted by a [`Thread`].
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// Emitted each time a streamed completion event has been applied to the
    /// thread.
    StreamedCompletion,
}
152
153impl EventEmitter<ThreadEvent> for Thread {}