ai.rs

  1use std::io;
  2use std::rc::Rc;
  3
  4use anyhow::{anyhow, Result};
  5use editor::Editor;
  6use futures::AsyncBufReadExt;
  7use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
  8use gpui::executor::Foreground;
  9use gpui::{actions, AppContext, Task, ViewContext};
 10use isahc::prelude::*;
 11use isahc::{http::StatusCode, Request};
 12use pulldown_cmark::{Event, HeadingLevel, Parser, Tag};
 13use serde::{Deserialize, Serialize};
 14use util::ResultExt;
 15
 16actions!(ai, [Assist]);
 17
 18// Data types for chat completion requests
 19#[derive(Serialize)]
 20struct OpenAIRequest {
 21    model: String,
 22    messages: Vec<RequestMessage>,
 23    stream: bool,
 24}
 25
 26#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 27struct RequestMessage {
 28    role: Role,
 29    content: String,
 30}
 31
 32#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 33struct ResponseMessage {
 34    role: Option<Role>,
 35    content: Option<String>,
 36}
 37
 38#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 39#[serde(rename_all = "lowercase")]
 40enum Role {
 41    User,
 42    Assistant,
 43    System,
 44}
 45
 46#[derive(Deserialize, Debug)]
 47struct OpenAIResponseStreamEvent {
 48    pub id: Option<String>,
 49    pub object: String,
 50    pub created: u32,
 51    pub model: String,
 52    pub choices: Vec<ChatChoiceDelta>,
 53    pub usage: Option<Usage>,
 54}
 55
 56#[derive(Deserialize, Debug)]
 57struct Usage {
 58    pub prompt_tokens: u32,
 59    pub completion_tokens: u32,
 60    pub total_tokens: u32,
 61}
 62
 63#[derive(Deserialize, Debug)]
 64struct ChatChoiceDelta {
 65    pub index: u32,
 66    pub delta: ResponseMessage,
 67    pub finish_reason: Option<String>,
 68}
 69
 70#[derive(Deserialize, Debug)]
 71struct OpenAIUsage {
 72    prompt_tokens: u64,
 73    completion_tokens: u64,
 74    total_tokens: u64,
 75}
 76
 77#[derive(Deserialize, Debug)]
 78struct OpenAIChoice {
 79    text: String,
 80    index: u32,
 81    logprobs: Option<serde_json::Value>,
 82    finish_reason: Option<String>,
 83}
 84
 85pub fn init(cx: &mut AppContext) {
 86    cx.add_async_action(assist)
 87}
 88
 89fn assist(
 90    editor: &mut Editor,
 91    _: &Assist,
 92    cx: &mut ViewContext<Editor>,
 93) -> Option<Task<Result<()>>> {
 94    let api_key = std::env::var("OPENAI_API_KEY").log_err()?;
 95
 96    let markdown = editor.text(cx);
 97    let prompt = parse_dialog(&markdown);
 98    let response = stream_completion(api_key, prompt, cx.foreground().clone());
 99
100    let range = editor.buffer().update(cx, |buffer, cx| {
101        let snapshot = buffer.snapshot(cx);
102        let chars = snapshot.reversed_chars_at(snapshot.len());
103        let trailing_newlines = chars.take(2).take_while(|c| *c == '\n').count();
104        let suffix = "\n".repeat(2 - trailing_newlines);
105        let end = snapshot.len();
106        buffer.edit([(end..end, suffix.clone())], None, cx);
107        let snapshot = buffer.snapshot(cx);
108        let start = snapshot.anchor_before(snapshot.len());
109        let end = snapshot.anchor_after(snapshot.len());
110        start..end
111    });
112    let buffer = editor.buffer().clone();
113
114    Some(cx.spawn(|_, mut cx| async move {
115        let mut stream = response.await?;
116        let mut message = String::new();
117        while let Some(stream_event) = stream.next().await {
118            if let Some(choice) = stream_event?.choices.first() {
119                if let Some(content) = &choice.delta.content {
120                    message.push_str(content);
121                }
122            }
123
124            buffer.update(&mut cx, |buffer, cx| {
125                buffer.edit([(range.clone(), message.clone())], None, cx);
126            });
127        }
128        Ok(())
129    }))
130}
131
132fn parse_dialog(markdown: &str) -> OpenAIRequest {
133    let parser = Parser::new(markdown);
134    let mut messages = Vec::new();
135
136    let mut current_role: Option<Role> = None;
137    let mut buffer = String::new();
138    for event in parser {
139        match event {
140            Event::Start(Tag::Heading(HeadingLevel::H2, _, _)) => {
141                if let Some(role) = current_role.take() {
142                    if !buffer.is_empty() {
143                        messages.push(RequestMessage {
144                            role,
145                            content: buffer.trim().to_string(),
146                        });
147                        buffer.clear();
148                    }
149                }
150            }
151            Event::Text(text) => {
152                if current_role.is_some() {
153                    buffer.push_str(&text);
154                } else {
155                    // Determine the current role based on the H2 header text
156                    let text = text.to_lowercase();
157                    current_role = if text.contains("user") {
158                        Some(Role::User)
159                    } else if text.contains("assistant") {
160                        Some(Role::Assistant)
161                    } else if text.contains("system") {
162                        Some(Role::System)
163                    } else {
164                        None
165                    };
166                }
167            }
168            _ => (),
169        }
170    }
171    if let Some(role) = current_role {
172        messages.push(RequestMessage {
173            role,
174            content: buffer,
175        });
176    }
177
178    OpenAIRequest {
179        model: "gpt-4".into(),
180        messages,
181        stream: true,
182    }
183}
184
185async fn stream_completion(
186    api_key: String,
187    mut request: OpenAIRequest,
188    executor: Rc<Foreground>,
189) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
190    request.stream = true;
191
192    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
193
194    let json_data = serde_json::to_string(&request)?;
195    let mut response = Request::post("https://api.openai.com/v1/chat/completions")
196        .header("Content-Type", "application/json")
197        .header("Authorization", format!("Bearer {}", api_key))
198        .body(json_data)?
199        .send_async()
200        .await?;
201
202    let status = response.status();
203    if status == StatusCode::OK {
204        executor
205            .spawn(async move {
206                let mut lines = BufReader::new(response.body_mut()).lines();
207
208                fn parse_line(
209                    line: Result<String, io::Error>,
210                ) -> Result<Option<OpenAIResponseStreamEvent>> {
211                    if let Some(data) = line?.strip_prefix("data: ") {
212                        let event = serde_json::from_str(&data)?;
213                        Ok(Some(event))
214                    } else {
215                        Ok(None)
216                    }
217                }
218
219                while let Some(line) = lines.next().await {
220                    if let Some(event) = parse_line(line).transpose() {
221                        tx.unbounded_send(event).log_err();
222                    }
223                }
224
225                anyhow::Ok(())
226            })
227            .detach();
228
229        Ok(rx)
230    } else {
231        let mut body = String::new();
232        response.body_mut().read_to_string(&mut body).await?;
233
234        Err(anyhow!(
235            "Failed to connect to OpenAI API: {} {}",
236            response.status(),
237            body,
238        ))
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_dialog() {
        use unindent::Unindent;

        let test_input = r#"
            ## System
            Hey there, welcome to Zed!

            ## Assistant
            Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.
        "#.unindent();

        // Each H2 heading names the role of the message that follows it.
        let expected_output = vec![
            RequestMessage {
                role: Role::System,
                content: "Hey there, welcome to Zed!".to_string(),
            },
            RequestMessage {
                role: Role::Assistant,
                content: "Thanks! I'm excited to be here. I have much to learn, but also much to teach, and I'm growing fast.".to_string(),
            },
        ];

        assert_eq!(parse_dialog(&test_input).messages, expected_output);
    }

    #[test]
    fn test_parse_dialog_skips_unrecognized_sections() {
        use unindent::Unindent;

        // A heading that names no role contributes no message, and neither does its body.
        let test_input = r#"
            ## Notes
            Some scratch notes.

            ## User
            Hi
        "#.unindent();

        let expected_output = vec![RequestMessage {
            role: Role::User,
            content: "Hi".to_string(),
        }];

        assert_eq!(parse_dialog(&test_input).messages, expected_output);
    }
}