1use std::sync::Arc;
2
3use futures::StreamExt as _;
4use gpui::{AppContext, EventEmitter, ModelContext, Task};
5use language_model::{
6 LanguageModel, LanguageModelCompletionEvent, LanguageModelRequest, LanguageModelRequestMessage,
7 MessageContent, Role, StopReason,
8};
9use util::ResultExt as _;
10
/// The kind of request being sent to the language model.
///
/// Only chat-style completions exist so far; presumably more variants are
/// added as other request flows appear — TODO confirm against callers.
#[derive(Debug, Clone, Copy)]
pub enum RequestKind {
    /// A conversational chat completion request.
    Chat,
}
15
/// A message in a [`Thread`].
pub struct Message {
    /// Who authored the message (e.g. [`Role::User`] or [`Role::Assistant`]).
    pub role: Role,
    /// The plain-text content of the message.
    pub text: String,
}
21
/// A thread of conversation with the LLM.
pub struct Thread {
    /// Messages in chronological order; streamed assistant output is appended
    /// to the last entry as chunks arrive.
    messages: Vec<Message>,
    /// Tasks driving in-flight completions; holding them keeps the streaming
    /// work alive. NOTE(review): completed tasks are never removed, so this
    /// vector grows for the lifetime of the thread — confirm this is intended.
    pending_completion_tasks: Vec<Task<()>>,
}
27
28impl Thread {
29 pub fn new(_cx: &mut ModelContext<Self>) -> Self {
30 Self {
31 messages: Vec::new(),
32 pending_completion_tasks: Vec::new(),
33 }
34 }
35
36 pub fn messages(&self) -> impl Iterator<Item = &Message> {
37 self.messages.iter()
38 }
39
40 pub fn insert_user_message(&mut self, text: impl Into<String>) {
41 self.messages.push(Message {
42 role: Role::User,
43 text: text.into(),
44 });
45 }
46
47 pub fn to_completion_request(
48 &self,
49 _request_kind: RequestKind,
50 _cx: &AppContext,
51 ) -> LanguageModelRequest {
52 let mut request = LanguageModelRequest {
53 messages: vec![],
54 tools: Vec::new(),
55 stop: Vec::new(),
56 temperature: None,
57 };
58
59 for message in &self.messages {
60 let mut request_message = LanguageModelRequestMessage {
61 role: message.role,
62 content: Vec::new(),
63 cache: false,
64 };
65
66 request_message
67 .content
68 .push(MessageContent::Text(message.text.clone()));
69
70 request.messages.push(request_message);
71 }
72
73 request
74 }
75
76 pub fn stream_completion(
77 &mut self,
78 request: LanguageModelRequest,
79 model: Arc<dyn LanguageModel>,
80 cx: &mut ModelContext<Self>,
81 ) {
82 let task = cx.spawn(|this, mut cx| async move {
83 let stream = model.stream_completion(request, &cx);
84 let stream_completion = async {
85 let mut events = stream.await?;
86 let mut stop_reason = StopReason::EndTurn;
87
88 while let Some(event) = events.next().await {
89 let event = event?;
90
91 this.update(&mut cx, |thread, cx| {
92 match event {
93 LanguageModelCompletionEvent::StartMessage { .. } => {
94 thread.messages.push(Message {
95 role: Role::Assistant,
96 text: String::new(),
97 });
98 }
99 LanguageModelCompletionEvent::Stop(reason) => {
100 stop_reason = reason;
101 }
102 LanguageModelCompletionEvent::Text(chunk) => {
103 if let Some(last_message) = thread.messages.last_mut() {
104 if last_message.role == Role::Assistant {
105 last_message.text.push_str(&chunk);
106 }
107 }
108 }
109 LanguageModelCompletionEvent::ToolUse(_tool_use) => {}
110 }
111
112 cx.emit(ThreadEvent::StreamedCompletion);
113 cx.notify();
114 })?;
115
116 smol::future::yield_now().await;
117 }
118
119 anyhow::Ok(stop_reason)
120 };
121
122 let result = stream_completion.await;
123 let _ = result.log_err();
124 });
125
126 self.pending_completion_tasks.push(task);
127 }
128}
129
/// Events emitted by a [`Thread`] while completions stream in.
#[derive(Debug, Clone)]
pub enum ThreadEvent {
    /// A completion event was processed; the thread's messages may have changed.
    StreamedCompletion,
}
134
135impl EventEmitter<ThreadEvent> for Thread {}