ai.rs

  1use anyhow::{anyhow, Result};
  2use collections::HashMap;
  3use editor::Editor;
  4use futures::AsyncBufReadExt;
  5use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
  6use gpui::executor::Background;
  7use gpui::{actions, AppContext, Task, ViewContext};
  8use isahc::prelude::*;
  9use isahc::{http::StatusCode, Request};
 10use serde::{Deserialize, Serialize};
 11use std::cell::RefCell;
 12use std::fs;
 13use std::rc::Rc;
 14use std::{io, sync::Arc};
 15use util::channel::{ReleaseChannel, RELEASE_CHANNEL};
 16use util::{ResultExt, TryFutureExt};
 17
 18use rust_embed::RustEmbed;
 19use std::str;
 20
/// Prompt/context files from the `../../assets/contexts` folder, exposed through
/// `rust_embed`. `.DS_Store` files are excluded.
#[derive(RustEmbed)]
#[folder = "../../assets/contexts"]
#[exclude = "*.DS_Store"]
pub struct ContextAssets;

// Declares the `Assist` action in the `ai` namespace; `init` registers its handler.
actions!(ai, [Assist]);
 27
// Data types for chat completion requests

/// Request body for the OpenAI chat completions endpoint, serialized to JSON by
/// `stream_completion`.
#[derive(Serialize)]
struct OpenAIRequest {
    /// Model name, e.g. `"gpt-4"`.
    model: String,
    /// Messages sent to the model, in order.
    messages: Vec<RequestMessage>,
    /// Whether the API should stream its response; `stream_completion` forces this to `true`.
    stream: bool,
}
 35
/// A single message sent as part of an [`OpenAIRequest`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
    /// Who authored the message.
    role: Role,
    /// The message text.
    content: String,
}
 41
/// The incremental `delta` carried by a streamed chat choice.
///
/// Both fields are `Option` so that a chunk omitting either of them still deserializes.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ResponseMessage {
    role: Option<Role>,
    /// Text fragment to append to the completion, if any.
    content: Option<String>,
}
 47
/// Author of a chat message; (de)serialized in lowercase
/// (`"user"`, `"assistant"`, `"system"`).
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Role {
    User,
    Assistant,
    System,
}
 55
/// One chunk of a streamed chat completion: the JSON payload that follows `data: ` on a
/// line of the response body (see `stream_completion`).
#[derive(Deserialize, Debug)]
struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    /// Choices in this chunk; `Assistant::assist` only consumes the last one.
    pub choices: Vec<ChatChoiceDelta>,
    /// Token usage, when the API includes it in this chunk.
    pub usage: Option<Usage>,
}
 65
/// Token counts reported in an [`OpenAIResponseStreamEvent`].
#[derive(Deserialize, Debug)]
struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
 72
/// One choice inside an [`OpenAIResponseStreamEvent`], holding the incremental message
/// `delta` for that choice.
#[derive(Deserialize, Debug)]
struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    pub finish_reason: Option<String>,
}
 79
/// Token usage with 64-bit counters.
///
/// NOTE(review): not referenced anywhere else in this file (the stream events use
/// [`Usage`]); presumably left over from an earlier API integration — confirm before removing.
#[derive(Deserialize, Debug)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
}
 86
/// A completion choice carrying plain `text` rather than a chat message.
///
/// NOTE(review): not referenced anywhere else in this file; the `text`/`logprobs` shape
/// looks like the legacy (non-chat) completions API — confirm whether it is still needed.
#[derive(Deserialize, Debug)]
struct OpenAIChoice {
    text: String,
    index: u32,
    logprobs: Option<serde_json::Value>,
    finish_reason: Option<String>,
}
 94
 95pub fn init(cx: &mut AppContext) {
 96    if *RELEASE_CHANNEL == ReleaseChannel::Stable {
 97        return;
 98    }
 99
100    let assistant = Rc::new(Assistant::default());
101    cx.add_action({
102        let assistant = assistant.clone();
103        move |editor: &mut Editor, _: &Assist, cx: &mut ViewContext<Editor>| {
104            assistant.assist(editor, cx).log_err();
105        }
106    });
107    cx.capture_action({
108        let assistant = assistant.clone();
109        move |_: &mut Editor, _: &editor::Cancel, cx: &mut ViewContext<Editor>| {
110            if !assistant.cancel_last_assist(cx.view_id()) {
111                cx.propagate_action();
112            }
113        }
114    });
115}
116
/// Identifier of a single assist request, allocated from
/// `AssistantState::next_completion_id`.
type CompletionId = usize;

/// Assistant shared between the action handlers registered in `init` and the tasks they
/// spawn. State lives in a `RefCell` because it is mutated through an `Rc`.
#[derive(Default)]
struct Assistant(RefCell<AssistantState>);

#[derive(Default)]
struct AssistantState {
    /// Running assist tasks per editor, keyed by the editor's view id. The most recently
    /// started assist is last, so cancelling pops from the end.
    assist_stacks: HashMap<usize, Vec<(CompletionId, Task<Option<()>>)>>,
    /// Next id to hand out; incremented each time an assist starts.
    next_completion_id: CompletionId,
}
127
128impl Assistant {
129    fn assist(self: &Rc<Self>, editor: &mut Editor, cx: &mut ViewContext<Editor>) -> Result<()> {
130        let api_key = std::env::var("OPENAI_API_KEY")?;
131
132        let selections = editor.selections.all(cx);
133        let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
134            // Insert markers around selected text as described in the system prompt above.
135            let snapshot = buffer.snapshot(cx);
136            let mut user_message = String::new();
137            let mut user_message_suffix = String::new();
138            let mut buffer_offset = 0;
139            for selection in selections {
140                if !selection.is_empty() {
141                    if user_message_suffix.is_empty() {
142                        user_message_suffix.push_str("\n\n");
143                    }
144                    user_message_suffix.push_str("[Selected excerpt from above]\n");
145                    user_message_suffix
146                        .extend(snapshot.text_for_range(selection.start..selection.end));
147                    user_message_suffix.push_str("\n\n");
148                }
149
150                user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
151                user_message.push_str("[SELECTION_START]");
152                user_message.extend(snapshot.text_for_range(selection.start..selection.end));
153                buffer_offset = selection.end;
154                user_message.push_str("[SELECTION_END]");
155            }
156            if buffer_offset < snapshot.len() {
157                user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
158            }
159            user_message.push_str(&user_message_suffix);
160
161            // Ensure the document ends with 4 trailing newlines.
162            let trailing_newline_count = snapshot
163                .reversed_chars_at(snapshot.len())
164                .take_while(|c| *c == '\n')
165                .take(4);
166            let buffer_suffix = "\n".repeat(4 - trailing_newline_count.count());
167            buffer.edit([(snapshot.len()..snapshot.len(), buffer_suffix)], None, cx);
168
169            let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing.
170            let insertion_site = snapshot.anchor_after(snapshot.len() - 2);
171
172            (user_message, insertion_site)
173        });
174
175        let this = self.clone();
176        let buffer = editor.buffer().clone();
177        let executor = cx.background_executor().clone();
178        let editor_id = cx.view_id();
179        let assist_id = util::post_inc(&mut self.0.borrow_mut().next_completion_id);
180        let assist_task = cx.spawn(|_, mut cx| {
181            async move {
182                // TODO: We should have a get_string method on assets. This is repateated elsewhere.
183                let content = ContextAssets::get("system.zmd").unwrap();
184                let mut system_message = std::str::from_utf8(content.data.as_ref())
185                    .unwrap()
186                    .to_string();
187
188                if let Ok(custom_system_message_path) =
189                    std::env::var("ZED_ASSISTANT_SYSTEM_PROMPT_PATH")
190                {
191                    system_message.push_str(
192                        "\n\nAlso consider the following user-defined system prompt:\n\n",
193                    );
194                    // TODO: Replace this with our file system trait object.
195                    system_message.push_str(
196                        &cx.background()
197                            .spawn(async move { fs::read_to_string(custom_system_message_path) })
198                            .await?,
199                    );
200                }
201
202                let stream = stream_completion(
203                    api_key,
204                    executor,
205                    OpenAIRequest {
206                        model: "gpt-4".to_string(),
207                        messages: vec![
208                            RequestMessage {
209                                role: Role::System,
210                                content: system_message.to_string(),
211                            },
212                            RequestMessage {
213                                role: Role::User,
214                                content: user_message,
215                            },
216                        ],
217                        stream: false,
218                    },
219                );
220
221                let mut messages = stream.await?;
222                while let Some(message) = messages.next().await {
223                    let mut message = message?;
224                    if let Some(choice) = message.choices.pop() {
225                        buffer.update(&mut cx, |buffer, cx| {
226                            let text: Arc<str> = choice.delta.content?.into();
227                            buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx);
228                            Some(())
229                        });
230                    }
231                }
232
233                this.0
234                    .borrow_mut()
235                    .assist_stacks
236                    .get_mut(&editor_id)
237                    .unwrap()
238                    .retain(|(id, _)| *id != assist_id);
239
240                anyhow::Ok(())
241            }
242            .log_err()
243        });
244
245        self.0
246            .borrow_mut()
247            .assist_stacks
248            .entry(cx.view_id())
249            .or_default()
250            .push((assist_id, assist_task));
251
252        Ok(())
253    }
254
255    fn cancel_last_assist(self: &Rc<Self>, editor_id: usize) -> bool {
256        self.0
257            .borrow_mut()
258            .assist_stacks
259            .get_mut(&editor_id)
260            .and_then(|assists| assists.pop())
261            .is_some()
262    }
263}
264
265async fn stream_completion(
266    api_key: String,
267    executor: Arc<Background>,
268    mut request: OpenAIRequest,
269) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
270    request.stream = true;
271
272    let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
273
274    let json_data = serde_json::to_string(&request)?;
275    let mut response = Request::post("https://api.openai.com/v1/chat/completions")
276        .header("Content-Type", "application/json")
277        .header("Authorization", format!("Bearer {}", api_key))
278        .body(json_data)?
279        .send_async()
280        .await?;
281
282    let status = response.status();
283    if status == StatusCode::OK {
284        executor
285            .spawn(async move {
286                let mut lines = BufReader::new(response.body_mut()).lines();
287
288                fn parse_line(
289                    line: Result<String, io::Error>,
290                ) -> Result<Option<OpenAIResponseStreamEvent>> {
291                    if let Some(data) = line?.strip_prefix("data: ") {
292                        let event = serde_json::from_str(&data)?;
293                        Ok(Some(event))
294                    } else {
295                        Ok(None)
296                    }
297                }
298
299                while let Some(line) = lines.next().await {
300                    if let Some(event) = parse_line(line).transpose() {
301                        tx.unbounded_send(event).log_err();
302                    }
303                }
304
305                anyhow::Ok(())
306            })
307            .detach();
308
309        Ok(rx)
310    } else {
311        let mut body = String::new();
312        response.body_mut().read_to_string(&mut body).await?;
313
314        Err(anyhow!(
315            "Failed to connect to OpenAI API: {} {}",
316            response.status(),
317            body,
318        ))
319    }
320}