thread.rs

  1use std::sync::Arc;
  2
  3use anyhow::Result;
  4use assistant_tool::ToolWorkingSet;
  5use collections::HashMap;
  6use futures::future::Shared;
  7use futures::{FutureExt as _, StreamExt as _};
  8use gpui::{AppContext, EventEmitter, ModelContext, SharedString, Task};
  9use language_model::{
 10    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
 11    LanguageModelToolResult, LanguageModelToolUse, LanguageModelToolUseId, MessageContent, Role,
 12    StopReason,
 13};
 14use language_models::provider::cloud::{MaxMonthlySpendReachedError, PaymentRequiredError};
 15use serde::{Deserialize, Serialize};
 16use util::post_inc;
 17
 18#[derive(Debug, Clone, Copy)]
 19pub enum RequestKind {
 20    Chat,
 21}
 22
 23#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)]
 24pub struct MessageId(usize);
 25
 26impl MessageId {
 27    fn post_inc(&mut self) -> Self {
 28        Self(post_inc(&mut self.0))
 29    }
 30}
 31
 32/// A message in a [`Thread`].
 33#[derive(Debug, Clone)]
 34pub struct Message {
 35    pub id: MessageId,
 36    pub role: Role,
 37    pub text: String,
 38}
 39
 40/// A thread of conversation with the LLM.
 41pub struct Thread {
 42    messages: Vec<Message>,
 43    next_message_id: MessageId,
 44    completion_count: usize,
 45    pending_completions: Vec<PendingCompletion>,
 46    tools: Arc<ToolWorkingSet>,
 47    tool_uses_by_message: HashMap<MessageId, Vec<LanguageModelToolUse>>,
 48    tool_results_by_message: HashMap<MessageId, Vec<LanguageModelToolResult>>,
 49    pending_tool_uses_by_id: HashMap<LanguageModelToolUseId, PendingToolUse>,
 50}
 51
 52impl Thread {
 53    pub fn new(tools: Arc<ToolWorkingSet>, _cx: &mut ModelContext<Self>) -> Self {
 54        Self {
 55            messages: Vec::new(),
 56            next_message_id: MessageId(0),
 57            completion_count: 0,
 58            pending_completions: Vec::new(),
 59            tools,
 60            tool_uses_by_message: HashMap::default(),
 61            tool_results_by_message: HashMap::default(),
 62            pending_tool_uses_by_id: HashMap::default(),
 63        }
 64    }
 65
 66    pub fn message(&self, id: MessageId) -> Option<&Message> {
 67        self.messages.iter().find(|message| message.id == id)
 68    }
 69
 70    pub fn tools(&self) -> &Arc<ToolWorkingSet> {
 71        &self.tools
 72    }
 73
 74    pub fn pending_tool_uses(&self) -> Vec<&PendingToolUse> {
 75        self.pending_tool_uses_by_id.values().collect()
 76    }
 77
 78    pub fn insert_user_message(&mut self, text: impl Into<String>, cx: &mut ModelContext<Self>) {
 79        let id = self.next_message_id.post_inc();
 80        self.messages.push(Message {
 81            id,
 82            role: Role::User,
 83            text: text.into(),
 84        });
 85        cx.emit(ThreadEvent::MessageAdded(id));
 86    }
 87
 88    pub fn to_completion_request(
 89        &self,
 90        _request_kind: RequestKind,
 91        _cx: &AppContext,
 92    ) -> LanguageModelRequest {
 93        let mut request = LanguageModelRequest {
 94            messages: vec![],
 95            tools: Vec::new(),
 96            stop: Vec::new(),
 97            temperature: None,
 98        };
 99
100        for message in &self.messages {
101            let mut request_message = LanguageModelRequestMessage {
102                role: message.role,
103                content: Vec::new(),
104                cache: false,
105            };
106
107            if let Some(tool_results) = self.tool_results_by_message.get(&message.id) {
108                for tool_result in tool_results {
109                    request_message
110                        .content
111                        .push(MessageContent::ToolResult(tool_result.clone()));
112                }
113            }
114
115            if !message.text.is_empty() {
116                request_message
117                    .content
118                    .push(MessageContent::Text(message.text.clone()));
119            }
120
121            if let Some(tool_uses) = self.tool_uses_by_message.get(&message.id) {
122                for tool_use in tool_uses {
123                    request_message
124                        .content
125                        .push(MessageContent::ToolUse(tool_use.clone()));
126                }
127            }
128
129            request.messages.push(request_message);
130        }
131
132        request
133    }
134
135    pub fn stream_completion(
136        &mut self,
137        request: LanguageModelRequest,
138        model: Arc<dyn LanguageModel>,
139        cx: &mut ModelContext<Self>,
140    ) {
141        let pending_completion_id = post_inc(&mut self.completion_count);
142
143        let task = cx.spawn(|thread, mut cx| async move {
144            let stream = model.stream_completion(request, &cx);
145            let stream_completion = async {
146                let mut events = stream.await?;
147                let mut stop_reason = StopReason::EndTurn;
148
149                while let Some(event) = events.next().await {
150                    let event = event?;
151
152                    thread.update(&mut cx, |thread, cx| {
153                        match event {
154                            LanguageModelCompletionEvent::StartMessage { .. } => {
155                                let id = thread.next_message_id.post_inc();
156                                thread.messages.push(Message {
157                                    id,
158                                    role: Role::Assistant,
159                                    text: String::new(),
160                                });
161                                cx.emit(ThreadEvent::MessageAdded(id));
162                            }
163                            LanguageModelCompletionEvent::Stop(reason) => {
164                                stop_reason = reason;
165                            }
166                            LanguageModelCompletionEvent::Text(chunk) => {
167                                if let Some(last_message) = thread.messages.last_mut() {
168                                    if last_message.role == Role::Assistant {
169                                        last_message.text.push_str(&chunk);
170                                    }
171                                }
172                            }
173                            LanguageModelCompletionEvent::ToolUse(tool_use) => {
174                                if let Some(last_assistant_message) = thread
175                                    .messages
176                                    .iter()
177                                    .rfind(|message| message.role == Role::Assistant)
178                                {
179                                    thread
180                                        .tool_uses_by_message
181                                        .entry(last_assistant_message.id)
182                                        .or_default()
183                                        .push(tool_use.clone());
184
185                                    thread.pending_tool_uses_by_id.insert(
186                                        tool_use.id.clone(),
187                                        PendingToolUse {
188                                            assistant_message_id: last_assistant_message.id,
189                                            id: tool_use.id,
190                                            name: tool_use.name,
191                                            input: tool_use.input,
192                                            status: PendingToolUseStatus::Idle,
193                                        },
194                                    );
195                                }
196                            }
197                        }
198
199                        cx.emit(ThreadEvent::StreamedCompletion);
200                        cx.notify();
201                    })?;
202
203                    smol::future::yield_now().await;
204                }
205
206                thread.update(&mut cx, |thread, _cx| {
207                    thread
208                        .pending_completions
209                        .retain(|completion| completion.id != pending_completion_id);
210                })?;
211
212                anyhow::Ok(stop_reason)
213            };
214
215            let result = stream_completion.await;
216
217            thread
218                .update(&mut cx, |_thread, cx| match result.as_ref() {
219                    Ok(stop_reason) => match stop_reason {
220                        StopReason::ToolUse => {
221                            cx.emit(ThreadEvent::UsePendingTools);
222                        }
223                        StopReason::EndTurn => {}
224                        StopReason::MaxTokens => {}
225                    },
226                    Err(error) => {
227                        if error.is::<PaymentRequiredError>() {
228                            cx.emit(ThreadEvent::ShowError(ThreadError::PaymentRequired));
229                        } else if error.is::<MaxMonthlySpendReachedError>() {
230                            cx.emit(ThreadEvent::ShowError(ThreadError::MaxMonthlySpendReached));
231                        } else {
232                            let error_message = error
233                                .chain()
234                                .map(|err| err.to_string())
235                                .collect::<Vec<_>>()
236                                .join("\n");
237                            cx.emit(ThreadEvent::ShowError(ThreadError::Message(
238                                SharedString::from(error_message.clone()),
239                            )));
240                        }
241                    }
242                })
243                .ok();
244        });
245
246        self.pending_completions.push(PendingCompletion {
247            id: pending_completion_id,
248            _task: task,
249        });
250    }
251
252    pub fn insert_tool_output(
253        &mut self,
254        assistant_message_id: MessageId,
255        tool_use_id: LanguageModelToolUseId,
256        output: Task<Result<String>>,
257        cx: &mut ModelContext<Self>,
258    ) {
259        let insert_output_task = cx.spawn(|thread, mut cx| {
260            let tool_use_id = tool_use_id.clone();
261            async move {
262                let output = output.await;
263                thread
264                    .update(&mut cx, |thread, cx| {
265                        // The tool use was requested by an Assistant message,
266                        // so we want to attach the tool results to the next
267                        // user message.
268                        let next_user_message = MessageId(assistant_message_id.0 + 1);
269
270                        let tool_results = thread
271                            .tool_results_by_message
272                            .entry(next_user_message)
273                            .or_default();
274
275                        match output {
276                            Ok(output) => {
277                                tool_results.push(LanguageModelToolResult {
278                                    tool_use_id: tool_use_id.to_string(),
279                                    content: output,
280                                    is_error: false,
281                                });
282
283                                cx.emit(ThreadEvent::ToolFinished { tool_use_id });
284                            }
285                            Err(err) => {
286                                tool_results.push(LanguageModelToolResult {
287                                    tool_use_id: tool_use_id.to_string(),
288                                    content: err.to_string(),
289                                    is_error: true,
290                                });
291
292                                if let Some(tool_use) =
293                                    thread.pending_tool_uses_by_id.get_mut(&tool_use_id)
294                                {
295                                    tool_use.status = PendingToolUseStatus::Error(err.to_string());
296                                }
297                            }
298                        }
299                    })
300                    .ok();
301            }
302        });
303
304        if let Some(tool_use) = self.pending_tool_uses_by_id.get_mut(&tool_use_id) {
305            tool_use.status = PendingToolUseStatus::Running {
306                _task: insert_output_task.shared(),
307            };
308        }
309    }
310}
311
312#[derive(Debug, Clone)]
313pub enum ThreadError {
314    PaymentRequired,
315    MaxMonthlySpendReached,
316    Message(SharedString),
317}
318
319#[derive(Debug, Clone)]
320pub enum ThreadEvent {
321    ShowError(ThreadError),
322    StreamedCompletion,
323    MessageAdded(MessageId),
324    UsePendingTools,
325    ToolFinished {
326        #[allow(unused)]
327        tool_use_id: LanguageModelToolUseId,
328    },
329}
330
331impl EventEmitter<ThreadEvent> for Thread {}
332
333struct PendingCompletion {
334    id: usize,
335    _task: Task<()>,
336}
337
338#[derive(Debug, Clone)]
339pub struct PendingToolUse {
340    pub id: LanguageModelToolUseId,
341    /// The ID of the Assistant message in which the tool use was requested.
342    pub assistant_message_id: MessageId,
343    pub name: String,
344    pub input: serde_json::Value,
345    pub status: PendingToolUseStatus,
346}
347
348#[derive(Debug, Clone)]
349pub enum PendingToolUseStatus {
350    Idle,
351    Running { _task: Shared<Task<()>> },
352    Error(#[allow(unused)] String),
353}
354
355impl PendingToolUseStatus {
356    pub fn is_idle(&self) -> bool {
357        matches!(self, PendingToolUseStatus::Idle)
358    }
359}