thread.rs

  1use std::sync::Arc;
  2
  3use futures::StreamExt as _;
  4use gpui::{AppContext, EventEmitter, ModelContext, Task};
  5use language_model::{
  6    LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
  7    MessageContent, Role, StopReason,
  8};
  9use util::ResultExt as _;
 10
 11#[derive(Debug, Clone, Copy)]
 12pub enum RequestKind {
 13    Chat,
 14}
 15
 16/// A message in a [`Thread`].
 17pub struct Message {
 18    pub role: Role,
 19    pub text: String,
 20}
 21
 22/// A thread of conversation with the LLM.
 23pub struct Thread {
 24    messages: Vec<Message>,
 25    pending_completion_tasks: Vec<Task<()>>,
 26}
 27
 28impl Thread {
 29    pub fn new(_cx: &mut ModelContext<Self>) -> Self {
 30        Self {
 31            messages: Vec::new(),
 32            pending_completion_tasks: Vec::new(),
 33        }
 34    }
 35
 36    pub fn messages(&self) -> impl Iterator<Item = &Message> {
 37        self.messages.iter()
 38    }
 39
 40    pub fn insert_user_message(&mut self, text: impl Into<String>) {
 41        self.messages.push(Message {
 42            role: Role::User,
 43            text: text.into(),
 44        });
 45    }
 46
 47    pub fn to_completion_request(
 48        &self,
 49        _request_kind: RequestKind,
 50        _cx: &AppContext,
 51    ) -> LanguageModelRequest {
 52        let mut request = LanguageModelRequest {
 53            messages: vec![],
 54            tools: Vec::new(),
 55            stop: Vec::new(),
 56            temperature: None,
 57        };
 58
 59        for message in &self.messages {
 60            let mut request_message = LanguageModelRequestMessage {
 61                role: message.role,
 62                content: Vec::new(),
 63                cache: false,
 64            };
 65
 66            request_message
 67                .content
 68                .push(MessageContent::Text(message.text.clone()));
 69
 70            request.messages.push(request_message);
 71        }
 72
 73        request
 74    }
 75
 76    pub fn stream_completion(
 77        &mut self,
 78        request: LanguageModelRequest,
 79        model: Arc<dyn LanguageModel>,
 80        cx: &mut ModelContext<Self>,
 81    ) {
 82        let task = cx.spawn(|this, mut cx| async move {
 83            let stream = model.stream_completion(request, &cx);
 84            let stream_completion = async {
 85                let mut events = stream.await?;
 86                let mut stop_reason = StopReason::EndTurn;
 87
 88                while let Some(event) = events.next().await {
 89                    let event = event?;
 90
 91                    this.update(&mut cx, |thread, cx| {
 92                        match event {
 93                            LanguageModelCompletionEvent::StartMessage { .. } => {
 94                                thread.messages.push(Message {
 95                                    role: Role::Assistant,
 96                                    text: String::new(),
 97                                });
 98                            }
 99                            LanguageModelCompletionEvent::Stop(reason) => {
100                                stop_reason = reason;
101                            }
102                            LanguageModelCompletionEvent::Text(chunk) => {
103                                if let Some(last_message) = thread.messages.last_mut() {
104                                    if last_message.role == Role::Assistant {
105                                        last_message.text.push_str(&chunk);
106                                    }
107                                }
108                            }
109                            LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
110                        }
111
112                        cx.emit(ThreadEvent::StreamedCompletion);
113                        cx.notify();
114                    })?;
115
116                    smol::future::yield_now().await;
117                }
118
119                anyhow::Ok(stop_reason)
120            };
121
122            let result = stream_completion.await;
123            let _ = result.log_err();
124        });
125
126        self.pending_completion_tasks.push(task);
127    }
128}
129
/// Notifications a [`Thread`] emits to its observers.
#[derive(Clone, Debug)]
pub enum ThreadEvent {
    /// A streamed completion event has just been applied to the thread.
    StreamedCompletion,
}
134
135impl EventEmitter<ThreadEvent> for Thread {}