thread.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use collections::HashMap;
  6use futures::future::Shared;
  7use futures::{FutureExt as _, StreamExt as _};
  8use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
 11    LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12    StopReason,
 13};
 14use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
 15use serde::{Deserialize, Serialize};
 16use util::post_inc;
 17
 18#[derive(Debug, Clone, Copy)]
 19pub enum RequestKind {
 20    Chat,
 21}
 22
 23#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 24pub struct MessageId(usize);
 25
 26impl MessageId {
 27    fn post_inc(&mut self) -> Self {
 28        Self(post_inc(&mut self.0))
 29    }
 30}
 31
 32/// A message in a [`Thread`].
 33#[derive(Debug, Clone)]
 34pub struct Message {
 35    pub id: MessageId,
 36    pub role: Role,
 37    pub text: String,
 38}
 39
 40/// A thread of conversation with the LLM.
 41pub struct Thread {
 42    messages: Vec<Message>,
 43    next_message_id: MessageId,
 44    completion_count: usize,
 45    pending_completions: Vec<PendingCompletion>,
 46    tools: Arc<ToolWorkingSet>,
 47    tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 48    tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
 49    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 50}
 51
 52impl Thread {
 53    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
 54        Self {
 55            messages: Vec::new(),
 56            next_message_id: MessageId(0),
 57            completion_count: 0,
 58            pending_completions: Vec::new(),
 59            tools,
 60            tool_uses_by_message: HashMap::default(),
 61            tool_results_by_message: HashMap::default(),
 62            pending_tool_uses_by_id: HashMap::default(),
 63        }
 64    }
 65
 66    pub fn message(&self, id: MessageId) -> Option<&Message> {
 67        self.messages.iter().find(|message| message.id == id)
 68    }
 69
 70    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 71        &self.tools
 72    }
 73
 74    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
 75        self.pending_tool_uses_by_id.values().collect()
 76    }
 77
 78    pub fn insert_user_message(&mut self, text: impl Into<String>, cx: &mut ModelContext<Self>) {
 79        let id = self.next_message_id.post_inc();
 80        self.messages.push(Message {
 81            id,
 82            role: Role::User,
 83            text: text.into(),
 84        });
 85        cx.emit(ThreadEvent::MessageAdded(id));
 86    }
 87
 88    pub fn to_completion_request(
 89        &self,
 90        _request_kind: RequestKind,
 91        _cx: &AppContext,
 92    ) -> LanguageModelRequest {
 93        let mut request = LanguageModelRequest {
 94            messages: vec![],
 95            tools: Vec::new(),
 96            stop: Vec::new(),
 97            temperature: None,
 98        };
 99
100        for message in &self.messages {
101            let mut request_message = LanguageModelRequestMessage {
102                role: message.role,
103                content: Vec::new(),
104                cache: false,
105            };
106
107            if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
108                for tool_result in tool_results {
109                    request_message
110                        .content
111                        .push(MessageContent::ToolResult(tool_result.clone()));
112                }
113            }
114
115            if !message.text.is_empty() {
116                request_message
117                    .content
118                    .push(MessageContent::Text(message.text.clone()));
119            }
120
121            if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
122                for tool_use in tool_uses {
123                    request_message
124                        .content
125                        .push(MessageContent::ToolUse(tool_use.clone()));
126                }
127            }
128
129            request.messages.push(request_message);
130        }
131
132        request
133    }
134
135    pub fn stream_completion(
136        &mut self,
137        request: LanguageModelRequest,
138        model: Arc<dyn LanguageModel>,
139        cx: &mut ModelContext<Self>,
140    ) {
141        let pending_completion_id = post_inc(&mut self.completion_count);
142
143        let task = cx.spawn(|thread, mut cx| async move {
144            let stream = model.stream_completion(request, &cx);
145            let stream_completion = async {
146                let mut events = stream.await?;
147                let mut stop_reason = StopReason::EndTurn;
148
149                while let Some(event) = events.next().await {
150                    let event = event?;
151
152                    thread.update(&mut cx, |thread, cx| {
153                        match event {
154                            LanguageModelCompletionEvent::StartMessage { .. } => {
155                                let id = thread.next_message_id.post_inc();
156                                thread.messages.push(Message {
157                                    id,
158                                    role: Role::Assistant,
159                                    text: String::new(),
160                                });
161                                cx.emit(ThreadEvent::MessageAdded(id));
162                            }
163                            LanguageModelCompletionEvent::Stop(reason) => {
164                                stop_reason = reason;
165                            }
166                            LanguageModelCompletionEvent::Text(chunk) => {
167                                if let Some(last_message) = thread.messages.last_mut() {
168                                    if last_message.role == Role::Assistant {
169                                        last_message.text.push_str(&chunk);
170                                        cx.emit(ThreadEvent::StreamedAssistantText(
171                                            last_message.id,
172                                            chunk,
173                                        ));
174                                    }
175                                }
176                            }
177                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
178                                if let Some(last_assistant_message) = thread
179                                    .messages
180                                    .iter()
181                                    .rfind(|message| message.role == Role::Assistant)
182                                {
183                                    thread
184                                        .tool_uses_by_message
185                                        .entry(last_assistant_message.id)
186                                        .or_default()
187                                        .push(tool_use.clone());
188
189                                    thread.pending_tool_uses_by_id.insert(
190                                        tool_use.id.clone(),
191                                        PendingToolUse {
192                                            assistant_message_id: last_assistant_message.id,
193                                            id: tool_use.id,
194                                            name: tool_use.name,
195                                            input: tool_use.input,
196                                            status: PendingToolUseStatus::Idle,
197                                        },
198                                    );
199                                }
200                            }
201                        }
202
203                        cx.emit(ThreadEvent::StreamedCompletion);
204                        cx.notify();
205                    })?;
206
207                    smol::future::yield_now().await;
208                }
209
210                thread.update(&mut cx, |thread, _cx| {
211                    thread
212                        .pending_completions
213                        .retain(|completion| completion.id != pending_completion_id);
214                })?;
215
216                anyhow::Ok(stop_reason)
217            };
218
219            let result = stream_completion.await;
220
221            thread
222                .update(&mut cx, |_thread, cx| match result.as_ref() {
223                    Ok(stop_reason) => match stop_reason {
224                        StopReason::ToolUse => {
225                            cx.emit(ThreadEvent::UsePendingTools);
226                        }
227                        StopReason::EndTurn => {}
228                        StopReason::MaxTokens => {}
229                    },
230                    Err(error) => {
231                        if error.is::<PaymentRequiredError>() {
232                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
233                        } else if error.is::<MaxMonthlySpendReachedError>() {
234                            cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
235                        } else {
236                            let error_message = error
237                                .chain()
238                                .map(|err| err.to_string())
239                                .collect::<Vec<_>>()
240                                .join("\n");
241                            cx.emit(ThreadEvent::ShowError(ThreadError::Message(
242                                SharedString::from(error_message.clone()),
243                            )));
244                        }
245                    }
246                })
247                .ok();
248        });
249
250        self.pending_completions.push(PendingCompletion {
251            id: pending_completion_id,
252            _task: task,
253        });
254    }
255
256    pub fn insert_tool_output(
257        &mut self,
258        assistant_message_id: MessageId,
259        tool_use_id: LanguageModelToolUseId,
260        output: Task<Result<String>>,
261        cx: &mut ModelContext<Self>,
262    ) {
263        let insert_output_task = cx.spawn(|thread, mut cx| {
264            let tool_use_id = tool_use_id.clone();
265            async move {
266                let output = output.await;
267                thread
268                    .update(&mut cx, |thread, cx| {
269                        // The tool use was requested by an Assistant message,
270                        // so we want to attach the tool results to the next
271                        // user message.
272                        let next_user_message = MessageId(assistant_message_id.0 + 1);
273
274                        let tool_results = thread
275                            .tool_results_by_message
276                            .entry(next_user_message)
277                            .or_default();
278
279                        match output {
280                            Ok(output) => {
281                                tool_results.push(LanguageModelToolResult {
282                                    tool_use_id: tool_use_id.to_string(),
283                                    content: output,
284                                    is_error: false,
285                                });
286
287                                cx.emit(ThreadEvent::ToolFinished { tool_use_id });
288                            }
289                            Err(err) => {
290                                tool_results.push(LanguageModelToolResult {
291                                    tool_use_id: tool_use_id.to_string(),
292                                    content: err.to_string(),
293                                    is_error: true,
294                                });
295
296                                if let Some(tool_use) =
297                                    thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
298                                {
299                                    tool_use.status = PendingToolUseStatus::Error(err.to_string());
300                                }
301                            }
302                        }
303                    })
304                    .ok();
305            }
306        });
307
308        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
309            tool_use.status = PendingToolUseStatus::Running {
310                _task: insert_output_task.shared(),
311            };
312        }
313    }
314}
315
316#[derive(Debug, Clone)]
317pub enum ThreadError {
318    PaymentRequired,
319    MaxMonthlySpendReached,
320    Message(SharedString),
321}
322
323#[derive(Debug, Clone)]
324pub enum ThreadEvent {
325    ShowError(ThreadError),
326    StreamedCompletion,
327    StreamedAssistantText(MessageId, String),
328    MessageAdded(MessageId),
329    UsePendingTools,
330    ToolFinished {
331        #[allow(unused)]
332        tool_use_id: LanguageModelToolUseId,
333    },
334}
335
336impl EventEmitter<ThreadEvent> for Thread {}
337
338struct PendingCompletion {
339    id: usize,
340    _task: Task<()>,
341}
342
343#[derive(Debug, Clone)]
344pub struct PendingToolUse {
345    pub id: LanguageModelToolUseId,
346    /// The ID of the Assistant message in which the tool use was requested.
347    pub assistant_message_id: MessageId,
348    pub name: String,
349    pub input: serde_json::Value,
350    pub status: PendingToolUseStatus,
351}
352
353#[derive(Debug, Clone)]
354pub enum PendingToolUseStatus {
355    Idle,
356    Running { _task: Shared<Task<()>> },
357    Error(#[allow(unused)] String),
358}
359
360impl PendingToolUseStatus {
361    pub fn is_idle(&self) -> bool {
362        matches!(self, PendingToolUseStatus::Idle)
363    }
364}