ai.rs

  1use anyhow::{anyhow, Result};
  2use assets::Assets;
  3use collections::HashMap;
  4use editor::Editor;
  5use futures::AsyncBufReadExt;
  6use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
  7use gpui::executor::Background;
  8use gpui::{actions, AppContext, Task, ViewContext};
  9use isahc::prelude::*;
 10use isahc::{http::StatusCode, Request};
 11use serde::{Deserialize, Serialize};
 12use std::cell::RefCell;
 13use std::fs;
 14use std::rc::Rc;
 15use std::{io, sync::Arc};
 16use util::channel::{ReleaseChannel, RELEASE_CHANNEL};
 17use util::{ResultExt, TryFutureExt};
 18
// Defines the `Assist` action in the `ai` namespace; `init` binds it to `Assistant::assist`
// for editors.
actions!(ai, [Assist]);
 20
// Data types for chat completion requests

/// JSON body posted to OpenAI's `/v1/chat/completions` endpoint by `stream_completion`.
#[derive(Serialize)]
struct OpenAIRequest {
    /// Model name, e.g. `"gpt-4"`.
    model: String,
    /// Messages sent to the model, serialized in vector order.
    messages: Vec<RequestMessage>,
    /// Whether to stream the response; `stream_completion` overwrites this with `true`.
    stream: bool,
}

/// A single message included in an [`OpenAIRequest`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
    role: Role,
    content: String,
}
 34
/// Message payload of a streamed chat completion choice (see `ChatChoiceDelta::delta`).
///
/// Both fields are optional; `Assistant::assist` inserts nothing for a chunk whose `content`
/// is `None`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ResponseMessage {
    role: Option<Role>,
    content: Option<String>,
}
 40
 41#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
 42#[serde(rename_all = "lowercase")]
 43enum Role {
 44    User,
 45    Assistant,
 46    System,
 47}
 48
/// One event of a streamed chat completion: the JSON payload of a `data: ` line, as parsed by
/// `stream_completion`.
#[derive(Deserialize, Debug)]
struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    // NOTE(review): presumably a Unix timestamp in seconds — confirm against the API docs.
    pub created: u32,
    pub model: String,
    /// Choices carried by this event; `Assistant::assist` only consumes the last one.
    pub choices: Vec<ChatChoiceDelta>,
    pub usage: Option<Usage>,
}

/// Token counts optionally attached to a stream event.
#[derive(Deserialize, Debug)]
struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

/// One choice within a stream event.
#[derive(Deserialize, Debug)]
struct ChatChoiceDelta {
    pub index: u32,
    /// Message content for this choice; `Assistant::assist` inserts its `content` into the buffer.
    pub delta: ResponseMessage,
    // NOTE(review): presumably `None` until the model stops generating — confirm.
    pub finish_reason: Option<String>,
}
 72
// NOTE(review): `OpenAIUsage` and `OpenAIChoice` are not referenced anywhere in this file, and
// `OpenAIUsage` duplicates `Usage` apart from its integer width. `OpenAIChoice` (with `text` and
// `logprobs`) looks like a choice from the non-chat completions API — confirm nothing else uses
// these types and consider removing them.
#[derive(Deserialize, Debug)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
}

#[derive(Deserialize, Debug)]
struct OpenAIChoice {
    text: String,
    index: u32,
    logprobs: Option<serde_json::Value>,
    finish_reason: Option<String>,
}
 87
 88pub fn init(cx: &mut AppContext) {
 89    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
 90        return;
 91    }
 92
 93    let assistant = Rc::new(Assistant::default());
 94    cx.add_action({
 95        let assistant = assistant.clone();
 96        move |editor: &mut Editor, _: &Assist, cx: &mut ViewContext<Editor>| {
 97            assistant.assist(editor, cx).log_err();
 98        }
 99    });
100    cx.capture_action({
101        let assistant = assistant.clone();
102        move |_: &mut Editor, _: &editor::Cancel, cx: &mut ViewContext<Editor>| {
103            if !assistant.cancel_last_assist(cx.view_id()) {
104                cx.propagate_action();
105            }
106        }
107    });
108}
109
/// Identifier of a single assist request, taken from `AssistantState::next_completion_id`.
type CompletionId = usize;

/// Assistant state shared (via `Rc`) between the action handlers registered in `init`.
#[derive(Default)]
struct Assistant(RefCell<AssistantState>);

#[derive(Default)]
struct AssistantState {
    /// Assist tasks per editor, keyed by the editor's view id. The most recently started assist
    /// is last, which is the one `cancel_last_assist` pops.
    assist_stacks: HashMap<usize, Vec<(CompletionId, Task<Option<()>>)>>,
    /// Id for the next assist; post-incremented each time an assist starts.
    next_completion_id: CompletionId,
}
120
121impl Assistant {
122    fn assist(self: &Rc<Self>, editor: &mut Editor, cx: &mut ViewContext<Editor>) -> Result<()> {
123        let api_key = std::env::var("OPENAI_API_KEY")?;
124
125        let selections = editor.selections.all(cx);
126        let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
127            // Insert markers around selected text as described in the system prompt above.
128            let snapshot = buffer.snapshot(cx);
129            let mut user_message = String::new();
130            let mut user_message_suffix = String::new();
131            let mut buffer_offset = 0;
132            for selection in selections {
133                if !selection.is_empty() {
134                    if user_message_suffix.is_empty() {
135                        user_message_suffix.push_str("\n\n");
136                    }
137                    user_message_suffix.push_str("[Selected excerpt from above]\n");
138                    user_message_suffix
139                        .extend(snapshot.text_for_range(selection.start..selection.end));
140                    user_message_suffix.push_str("\n\n");
141                }
142
143                user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
144                user_message.push_str("[SELECTION_START]");
145                user_message.extend(snapshot.text_for_range(selection.start..selection.end));
146                buffer_offset = selection.end;
147                user_message.push_str("[SELECTION_END]");
148            }
149            if buffer_offset < snapshot.len() {
150                user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
151            }
152            user_message.push_str(&user_message_suffix);
153
154            // Ensure the document ends with 4 trailing newlines.
155            let trailing_newline_count = snapshot
156                .reversed_chars_at(snapshot.len())
157                .take_while(|c| *c == '\n')
158                .take(4);
159            let buffer_suffix = "\n".repeat(4 - trailing_newline_count.count());
160            buffer.edit([(snapshot.len()..snapshot.len(), buffer_suffix)], None, cx);
161
162            let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing.
163            let insertion_site = snapshot.anchor_after(snapshot.len() - 2);
164
165            (user_message, insertion_site)
166        });
167
168        let this = self.clone();
169        let buffer = editor.buffer().clone();
170        let executor = cx.background_executor().clone();
171        let editor_id = cx.view_id();
172        let assist_id = util::post_inc(&mut self.0.borrow_mut().next_completion_id);
173        let assist_task = cx.spawn(|_, mut cx| {
174            async move {
175                // TODO: We should have a get_string method on assets. This is repateated elsewhere.
176                let content = Assets::get("contexts/system.zmd").unwrap();
177                let mut system_message = std::str::from_utf8(content.data.as_ref())
178                    .unwrap()
179                    .to_string();
180
181                if let Ok(custom_system_message_path) =
182                    std::env::var("ZED_ASSISTANT_SYSTEM_PROMPT_PATH")
183                {
184                    system_message.push_str(
185                        "\n\nAlso consider the following user-defined system prompt:\n\n",
186                    );
187                    // TODO: Replace this with our file system trait object.
188                    system_message.push_str(
189                        &cx.background()
190                            .spawn(async move { fs::read_to_string(custom_system_message_path) })
191                            .await?,
192                    );
193                }
194
195                let stream = stream_completion(
196                    api_key,
197                    executor,
198                    OpenAIRequest {
199                        model: "gpt-4".to_string(),
200                        messages: vec![
201                            RequestMessage {
202                                role: Role::System,
203                                content: system_message.to_string(),
204                            },
205                            RequestMessage {
206                                role: Role::User,
207                                content: user_message,
208                            },
209                        ],
210                        stream: false,
211                    },
212                );
213
214                let mut messages = stream.await?;
215                while let Some(message) = messages.next().await {
216                    let mut message = message?;
217                    if let Some(choice) = message.choices.pop() {
218                        buffer.update(&mut cx, |buffer, cx| {
219                            let text: Arc<str> = choice.delta.content?.into();
220                            buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx);
221                            Some(())
222                        });
223                    }
224                }
225
226                this.0
227                    .borrow_mut()
228                    .assist_stacks
229                    .get_mut(&editor_id)
230                    .unwrap()
231                    .retain(|(id, _)| *id != assist_id);
232
233                anyhow::Ok(())
234            }
235            .log_err()
236        });
237
238        self.0
239            .borrow_mut()
240            .assist_stacks
241            .entry(cx.view_id())
242            .or_default()
243            .push((assist_id, assist_task));
244
245        Ok(())
246    }
247
248    fn cancel_last_assist(self: &Rc<Self>, editor_id: usize) -> bool {
249        self.0
250            .borrow_mut()
251            .assist_stacks
252            .get_mut(&editor_id)
253            .and_then(|assists| assists.pop())
254            .is_some()
255    }
256}
257
258async fn stream_completion(
259    api_key: String,
260    executor: Arc<Background>,
261    mut request: OpenAIRequest,
262) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
263    request.stream = true;
264
265    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
266
267    let json_data = serde_json::to_string(&request)?;
268    let mut response = Request::post("https://api.openai.com/v1/chat/completions")
269        .header("Content-Type", "application/json")
270        .header("Authorization", format!("Bearer {}", api_key))
271        .body(json_data)?
272        .send_async()
273        .await?;
274
275    let status = response.status();
276    if status == StatusCode::OK {
277        executor
278            .spawn(async move {
279                let mut lines = BufReader::new(response.body_mut()).lines();
280
281                fn parse_line(
282                    line: Result<String, io::Error>,
283                ) -> Result<Option<OpenAIResponseStreamEvent>> {
284                    if let Some(data) = line?.strip_prefix("data: ") {
285                        let event = serde_json::from_str(&data)?;
286                        Ok(Some(event))
287                    } else {
288                        Ok(None)
289                    }
290                }
291
292                while let Some(line) = lines.next().await {
293                    if let Some(event) = parse_line(line).transpose() {
294                        tx.unbounded_send(event).log_err();
295                    }
296                }
297
298                anyhow::Ok(())
299            })
300            .detach();
301
302        Ok(rx)
303    } else {
304        let mut body = String::new();
305        response.body_mut().read_to_string(&mut body).await?;
306
307        Err(anyhow!(
308            "Failed to connect to OpenAI API: {} {}",
309            response.status(),
310            body,
311        ))
312    }
313}