// ai.rs — OpenAI chat-completion "assist" action for the Zed editor.

  1use anyhow::{anyhow, Result};
  2use editor::Editor;
  3use futures::AsyncBufReadExt;
  4use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
  5use gpui::executor::Background;
  6use gpui::{actions, AppContext, Task, ViewContext};
  7use indoc::indoc;
  8use isahc::prelude::*;
  9use isahc::{http::StatusCode, Request};
 10use serde::{Deserialize, Serialize};
 11use std::{io, sync::Arc};
 12use util::ResultExt;
 13
// Declares the `ai::Assist` action so it can be bound in keymaps and dispatched to `assist`.
actions!(ai, [Assist]);
 15
// Data types for chat completion requests

/// Request body for the OpenAI `/v1/chat/completions` endpoint.
#[derive(Serialize)]
struct OpenAIRequest {
    // Model identifier, e.g. "gpt-4".
    model: String,
    // Conversation so far, in order (system prompt first, then user turns).
    messages: Vec<RequestMessage>,
    // When true, the API responds with server-sent events instead of a single
    // JSON body. `stream_completion` forces this to true before sending.
    stream: bool,
}
 23
/// One chat message sent to the API; both fields are required on requests.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
    role: Role,
    content: String,
}
 29
/// A (possibly partial) message received from the API. In streaming mode each
/// chunk's `delta` may carry only a role, only content, or neither — hence
/// both fields are optional, unlike `RequestMessage`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ResponseMessage {
    role: Option<Role>,
    content: Option<String>,
}
 35
/// Chat participant role; serialized as "user" / "assistant" / "system"
/// via the `rename_all = "lowercase"` attribute.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Role {
    User,
    Assistant,
    System,
}
 43
/// One server-sent event from a streaming chat-completion response.
#[derive(Deserialize, Debug)]
struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    // Creation timestamp — presumably Unix seconds per the OpenAI API;
    // note u32 would overflow in 2106.
    pub created: u32,
    pub model: String,
    // Candidate completions; with default request settings there is one choice.
    pub choices: Vec<ChatChoiceDelta>,
    // Token accounting; absent on most stream chunks.
    pub usage: Option<Usage>,
}
 53
/// Token usage accounting attached to a response event.
#[derive(Deserialize, Debug)]
struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
 60
/// One streamed choice: an incremental `delta` to append to the message so
/// far, plus a `finish_reason` (e.g. "stop") on the final chunk.
#[derive(Deserialize, Debug)]
struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    pub finish_reason: Option<String>,
}
 67
// NOTE(review): appears unused in this file and duplicates `Usage` (with u64
// fields) — candidate for removal; verify there are no other references.
#[derive(Deserialize, Debug)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
}
 74
// NOTE(review): appears unused in this file — it matches the shape of the
// non-chat `/v1/completions` response, not the chat endpoint used here.
// Candidate for removal; verify there are no other references.
#[derive(Deserialize, Debug)]
struct OpenAIChoice {
    text: String,
    index: u32,
    logprobs: Option<serde_json::Value>,
    finish_reason: Option<String>,
}
 82
/// Registers the `assist` handler for the `ai::Assist` action.
/// Call once during app startup.
pub fn init(cx: &mut AppContext) {
    cx.add_async_action(assist)
}
 86
/// Handles the `ai::Assist` action: serializes the buffer (with selections
/// marked), sends it to OpenAI, and streams the reply back into the buffer.
///
/// Returns `None` (no task) when `OPENAI_API_KEY` is unset; otherwise returns
/// a task that drives the streaming insertion and resolves to its result.
fn assist(
    editor: &mut Editor,
    _: &Assist,
    cx: &mut ViewContext<Editor>,
) -> Option<Task<Result<()>>> {
    // Missing key => log via log_err() and bail with None (action is a no-op).
    let api_key = std::env::var("OPENAI_API_KEY").log_err()?;

    // System prompt teaching the model this file's inline conventions:
    // a leading `/` is a "model mention", `->-> ... <-<-` brackets the user's
    // selection, and replies are wrapped in `>` ... `<`.
    const SYSTEM_MESSAGE: &'static str = indoc! {r#"
        You an AI language model embedded in a code editor named Zed, authored by Zed Industries.
        The input you are currently processing was produced by a special \"model mention\" in a document that is open in the editor.
        A model mention is indicated via a leading / on a line.
        The user's currently selected text is indicated via ->->selected text<-<- surrounding selected text.
        In this sentence, the word ->->example<-<- is selected.
        Respond to any selected model mention.

        Wrap your responses in > < as follows.
        / What do you think?
        > I think that's a great idea. <

        For lines that are likely to wrap, or multiline responses, start and end the > and < on their own lines.
        >
        I think that's a great idea
        <

        If the selected mention is not at the end of the document, briefly summarize the context.
        > Key ideas of generative programming:
        * Managing context
            * Managing length
            * Context distillation
                - Shrink a context's size without loss of meaning.
        * Fine-grained version control
            * Portals to other contexts
                * Distillation policies
                * Budgets
        <
    "#};

    let selections = editor.selections.all(cx);
    let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
        // Insert ->-> <-<- around selected text as described in the system prompt above.
        // Selections are walked in order; `buffer_offset` tracks how much of the
        // document has already been copied into `user_message`.
        let snapshot = buffer.snapshot(cx);
        let mut user_message = String::new();
        let mut buffer_offset = 0;
        for selection in selections {
            user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
            user_message.push_str("->->");
            user_message.extend(snapshot.text_for_range(selection.start..selection.end));
            buffer_offset = selection.end;
            user_message.push_str("<-<-");
        }
        // Copy whatever follows the last selection.
        if buffer_offset < snapshot.len() {
            user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
        }

        // Ensure the document ends with 4 trailing newlines.
        // (`take(4)` caps the count so the subtraction below cannot underflow.)
        let trailing_newline_count = snapshot
            .reversed_chars_at(snapshot.len())
            .take_while(|c| *c == '\n')
            .take(4);
        let suffix = "\n".repeat(4 - trailing_newline_count.count());
        buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx);

        let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing.
        // Anchor two characters before the end — between the trailing newlines —
        // so streamed text is inserted there as it arrives.
        let insertion_site = snapshot.anchor_after(snapshot.len() - 2);

        (user_message, insertion_site)
    });

    let stream = stream_completion(
        api_key,
        cx.background_executor().clone(),
        OpenAIRequest {
            model: "gpt-4".to_string(),
            messages: vec![
                RequestMessage {
                    role: Role::System,
                    content: SYSTEM_MESSAGE.to_string(),
                },
                RequestMessage {
                    role: Role::User,
                    content: user_message,
                },
            ],
            // Value here is irrelevant: stream_completion overwrites it with true.
            stream: false,
        },
    );
    let buffer = editor.buffer().clone();
    Some(cx.spawn(|_, mut cx| async move {
        // Append each streamed delta at the insertion anchor as it arrives.
        let mut messages = stream.await?;
        while let Some(message) = messages.next().await {
            let mut message = message?;
            if let Some(choice) = message.choices.pop() {
                buffer.update(&mut cx, |buffer, cx| {
                    // `?` short-circuits the closure (to None) for chunks with
                    // no content, e.g. the initial role-only delta.
                    let text: Arc<str> = choice.delta.content?.into();
                    buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx);
                    Some(())
                });
            }
        }
        Ok(())
    }))
}
189
190async fn stream_completion(
191    api_key: String,
192    executor: Arc<Background>,
193    mut request: OpenAIRequest,
194) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
195    request.stream = true;
196
197    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
198
199    let json_data = serde_json::to_string(&request)?;
200    let mut response = Request::post("https://api.openai.com/v1/chat/completions")
201        .header("Content-Type", "application/json")
202        .header("Authorization", format!("Bearer {}", api_key))
203        .body(json_data)?
204        .send_async()
205        .await?;
206
207    let status = response.status();
208    if status == StatusCode::OK {
209        executor
210            .spawn(async move {
211                let mut lines = BufReader::new(response.body_mut()).lines();
212
213                fn parse_line(
214                    line: Result<String, io::Error>,
215                ) -> Result<Option<OpenAIResponseStreamEvent>> {
216                    if let Some(data) = line?.strip_prefix("data: ") {
217                        let event = serde_json::from_str(&data)?;
218                        Ok(Some(event))
219                    } else {
220                        Ok(None)
221                    }
222                }
223
224                while let Some(line) = lines.next().await {
225                    if let Some(event) = parse_line(line).transpose() {
226                        tx.unbounded_send(event).log_err();
227                    }
228                }
229
230                anyhow::Ok(())
231            })
232            .detach();
233
234        Ok(rx)
235    } else {
236        let mut body = String::new();
237        response.body_mut().read_to_string(&mut body).await?;
238
239        Err(anyhow!(
240            "Failed to connect to OpenAI API: {} {}",
241            response.status(),
242            body,
243        ))
244    }
245}
246
// No unit tests yet: streaming against the live OpenAI API would require a
// mock HTTP layer to test deterministically.
#[cfg(test)]
mod tests {}