use std::io;
use std::rc::Rc;

use anyhow::{anyhow, Result};
use editor::Editor;
use futures::AsyncBufReadExt;
use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
use gpui::executor::Foreground;
use gpui::{actions, AppContext, Task, ViewContext};
use isahc::prelude::*;
use isahc::{http::StatusCode, Request};
use pulldown_cmark::{Event, HeadingLevel, Parser, Tag};
use serde::{Deserialize, Serialize};
use util::ResultExt;

actions!(ai, [Assist]);

// Data types for chat completion requests
#[derive(Serialize)]
struct OpenAIRequest {
    model: String,
    messages: Vec<RequestMessage>,
    stream: bool,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
    role: Role,
    content: String,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ResponseMessage {
    role: Option<Role>,
    content: Option<String>,
}

#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Role {
    User,
    Assistant,
    System,
}

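// One server-sent event from the streaming chat completions endpoint; each
// event carries incremental message deltas in `choices`.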
#[derive(Deserialize, Debug)]
struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<Usage>,
}

#[derive(Deserialize, Debug)]
struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Deserialize, Debug)]
struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    pub finish_reason: Option<String>,
}

#[derive(Deserialize, Debug)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
}

#[derive(Deserialize, Debug)]
struct OpenAIChoice {
    text: String,
    index: u32,
    logprobs: Option<serde_json::Value>,
    finish_reason: Option<String>,
}

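// Registers the async `assist` action handler with the app.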
pub fn init(cx: &mut AppContext) {
    cx.add_async_action(assist)
}

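// Handles the `Assist` action: parses the editor's markdown as a dialog, sends
// it to the chat completions endpoint, and streams the reply into the buffer.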
fn assist(
    editor: &mut Editor,
    _: &Assist,
    cx: &mut ViewContext<Editor>,
) -> Option<Task<Result<()>>> {
    let api_key = std::env::var("OPENAI_API_KEY").log_err()?;

    let markdown = editor.text(cx);
    let prompt = parse_dialog(&markdown);
    let response = stream_completion(api_key, prompt, cx.foreground().clone());

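    // Ensure the buffer ends with two newlines, then anchor an empty range at
    // the very end; the streamed reply is written into this range.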
    let range = editor.buffer().update(cx, |buffer, cx| {
        let snapshot = buffer.snapshot(cx);
        let chars = snapshot.reversed_chars_at(snapshot.len());
        let trailing_newlines = chars.take(2).take_while(|c| *c == '\n').count();
        let suffix = "\n".repeat(2 - trailing_newlines);
        let end = snapshot.len();
        buffer.edit([(end..end, suffix.clone())], None, cx);
        let snapshot = buffer.snapshot(cx);
        let start = snapshot.anchor_before(snapshot.len());
        let end = snapshot.anchor_after(snapshot.len());
        start..end
    });
    let buffer = editor.buffer().clone();

    Some(cx.spawn(|_, mut cx| async move {
        let mut stream = response.await?;
        let mut message = String::new();
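        // Accumulate the streamed deltas and rewrite the anchored range with
        // the full message received so far after each event.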
        while let Some(stream_event) = stream.next().await {
            if let Some(choice) = stream_event?.choices.first() {
                if let Some(content) = &choice.delta.content {
                    message.push_str(content);
                }
            }

            buffer.update(&mut cx, |buffer, cx| {
                buffer.edit([(range.clone(), message.clone())], None, cx);
            });
        }
        Ok(())
    }))
}

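// Converts a markdown dialog into a chat completion request. Each `## User`,
// `## Assistant`, or `## System` heading starts a new message whose role is
// inferred from the heading text, and the text that follows becomes its content.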
fn parse_dialog(markdown: &str) -> OpenAIRequest {
    let parser = Parser::new(markdown);
    let mut messages = Vec::new();

    let mut current_role: Option<Role> = None;
    let mut buffer = String::new();
    for event in parser {
        match event {
            Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => {
                if let Some(role) = current_role.take() {
                    if !buffer.is_empty() {
                        messages.push(RequestMessage {
                            role,
                            content: buffer.trim().to_string(),
                        });
                        buffer.clear();
                    }
                }
            }
            Event::Text(text) => {
                if current_role.is_some() {
                    buffer.push_str(&text);
                } else {
                    // Determine the current role based on the H2 header text
                    let text = text.to_lowercase();
                    current_role = if text.contains("user") {
                        Some(Role::User)
                    } else if text.contains("assistant") {
                        Some(Role::Assistant)
                    } else if text.contains("system") {
                        Some(Role::System)
                    } else {
                        None
                    };
                }
            }
            _ => (),
        }
    }
    if let Some(role) = current_role {
        if !buffer.is_empty() {
            messages.push(RequestMessage {
                role,
                content: buffer.trim().to_string(),
            });
        }
    }

    OpenAIRequest {
        model: "gpt-4".into(),
        messages,
        stream: true,
    }
}

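// Sends the chat completion request with `stream: true` and returns a stream of
// parsed response events, read line by line from the HTTP response body on a
// spawned task and forwarded over an unbounded channel.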
async fn stream_completion(
    api_key: String,
    mut request: OpenAIRequest,
    executor: Rc<Foreground>,
) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
    request.stream = true;

    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();

    let json_data = serde_json::to_string(&request)?;
    let mut response = Request::post("https://api.openai.com/v1/chat/completions")
        .header("Content-Type", "application/json")
        .header("Authorization", format!("Bearer {}", api_key))
        .body(json_data)?
        .send_async()
        .await?;

    let status = response.status();
    if status == StatusCode::OK {
        executor
            .spawn(async move {
                let mut lines = BufReader::new(response.body_mut()).lines();

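                // Each event arrives on a line of the form `data: {json}`;
                // lines without that prefix are ignored.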
                fn parse_line(
                    line: Result<String, io::Error>,
                ) -> Result<Option<OpenAIResponseStreamEvent>> {
                    if let Some(data) = line?.strip_prefix("data: ") {
                        // The server ends the stream with a literal
                        // `data: [DONE]` line, which is not valid JSON.
                        if data == "[DONE]" {
                            Ok(None)
                        } else {
                            let event = serde_json::from_str(data)?;
                            Ok(Some(event))
                        }
                    } else {
                        Ok(None)
                    }
                }

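                // Forward each parsed event to the receiver; lines that yield
                // no event are skipped via `transpose`.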
                while let Some(line) = lines.next().await {
                    if let Some(event) = parse_line(line).transpose() {
                        tx.unbounded_send(event).log_err();
                    }
                }

                anyhow::Ok(())
            })
            .detach();

        Ok(rx)
    } else {
        let mut body = String::new();
        response.body_mut().read_to_string(&mut body).await?;

        Err(anyhow!(
            "Failed to connect to OpenAI API: {} {}",
            status,
            body,
        ))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_dialog() {
        use unindent::Unindent;

        let test_input = r#"
            ## System
            Hey there, welcome to Zed!

            ## Assistant
            Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.
        "#.unindent();

        let expected_output = vec![
            RequestMessage {
                role: Role::System,
                content: "Hey there, welcome to Zed!".to_string(),
            },
            RequestMessage {
                role: Role::Assistant,
                content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(),
            },
        ];

        assert_eq!(parse_dialog(&test_input).messages, expected_output);
    }
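
    // An additional minimal check: a dialog with a single `## User` section
    // should yield one user message, and the request should ask for streaming.
    #[test]
    fn test_parse_dialog_single_user_message() {
        use unindent::Unindent;

        let test_input = r#"
            ## User
            Hello, assistant!
        "#.unindent();

        let request = parse_dialog(&test_input);
        assert_eq!(
            request.messages,
            vec![RequestMessage {
                role: Role::User,
                content: "Hello, assistant!".to_string(),
            }]
        );
        assert!(request.stream);
    }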
}