1use anyhow::{anyhow, Result};
2use editor::Editor;
3use futures::AsyncBufReadExt;
4use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
5use gpui::executor::Background;
6use gpui::{actions, AppContext, Task, ViewContext};
7use indoc::indoc;
8use isahc::prelude::*;
9use isahc::{http::StatusCode, Request};
10use serde::{Deserialize, Serialize};
11use std::{io, sync::Arc};
12use util::ResultExt;
13
// Declares the `Assist` action (namespace `ai`); `init` registers `assist` as its handler.
actions!(ai, [Assist]);
15
// Data types for chat completion requests

/// Request body posted to OpenAI's `/v1/chat/completions` endpoint.
#[derive(Serialize)]
struct OpenAIRequest {
    /// Model identifier, e.g. `"gpt-4"`.
    model: String,
    /// Chat messages sent to the model, in order.
    messages: Vec<RequestMessage>,
    /// Requests a streamed response; `stream_completion` always forces this to `true`.
    stream: bool,
}
23
/// A single message in an [`OpenAIRequest`].
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
    /// Author of the message.
    role: Role,
    /// Message text.
    content: String,
}
29
/// Incremental message content carried by a [`ChatChoiceDelta`] in a streamed response.
///
/// Both fields are `Option`, so a delta that omits either one still deserializes.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ResponseMessage {
    /// Author role, if present in this delta.
    role: Option<Role>,
    /// Text fragment to append, if present in this delta.
    content: Option<String>,
}
35
/// Author of a chat message.
///
/// Serialized in lowercase: `"user"`, `"assistant"`, `"system"`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Role {
    User,
    Assistant,
    System,
}
43
/// One event parsed from a `data: ` line of a streamed chat completions response.
#[derive(Deserialize, Debug)]
struct OpenAIResponseStreamEvent {
    /// Identifier of the completion, if the server sent one.
    pub id: Option<String>,
    /// Object type string reported by the API.
    pub object: String,
    /// Creation time — presumably a Unix timestamp in seconds; confirm against the API docs.
    pub created: u32,
    /// Model that produced this event.
    pub model: String,
    /// Choices carried by this event; `assist` only uses the last one.
    pub choices: Vec<ChatChoiceDelta>,
    /// Token usage, if the server included it.
    pub usage: Option<Usage>,
}
53
/// Token counts attached to an [`OpenAIResponseStreamEvent`].
#[derive(Deserialize, Debug)]
struct Usage {
    /// Tokens consumed by the request messages.
    pub prompt_tokens: u32,
    /// Tokens produced in the completion.
    pub completion_tokens: u32,
    /// Total tokens reported by the server.
    pub total_tokens: u32,
}
60
/// One choice inside a streamed [`OpenAIResponseStreamEvent`].
#[derive(Deserialize, Debug)]
struct ChatChoiceDelta {
    /// Index of this choice as reported by the server.
    pub index: u32,
    /// Content added by this event.
    pub delta: ResponseMessage,
    /// Reason generation finished, if the server reported one in this event.
    pub finish_reason: Option<String>,
}
67
/// Token counts with 64-bit counters.
///
/// NOTE(review): not referenced anywhere else in this file and duplicates [`Usage`]
/// (which uses `u32`); candidate for removal or for replacing `Usage`.
#[derive(Deserialize, Debug)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
}
74
/// A completion choice carrying plain `text` rather than a chat message delta.
///
/// NOTE(review): not referenced anywhere else in this file; presumably left over from a
/// non-chat completions API — confirm before relying on it or remove it.
#[derive(Deserialize, Debug)]
struct OpenAIChoice {
    text: String,
    index: u32,
    logprobs: Option<serde_json::Value>,
    finish_reason: Option<String>,
}
82
83pub fn init(cx: &mut AppContext) {
84 cx.add_async_action(assist)
85}
86
/// Handles the [`Assist`] action by sending the buffer to OpenAI and streaming the reply back.
///
/// The whole buffer is sent as the user message, with every selection wrapped in
/// `->->` / `<-<-` markers as the system prompt describes. The buffer is then padded so it
/// ends with (at least) four newlines, and reply fragments are inserted two characters
/// before its end.
///
/// Returns `None` when `OPENAI_API_KEY` cannot be read (the error goes through `log_err`).
/// Otherwise returns a task that completes when the response stream ends, or fails with
/// the first error the request or the stream produces.
fn assist(
    editor: &mut Editor,
    _: &Assist,
    cx: &mut ViewContext<Editor>,
) -> Option<Task<Result<()>>> {
    let api_key = std::env::var("OPENAI_API_KEY").log_err()?;

    // System prompt; `indoc!` strips the common leading indentation of its lines.
    const SYSTEM_MESSAGE: &'static str = indoc! {r#"
        You an AI language model embedded in a code editor named Zed, authored by Zed Industries.
        The input you are currently processing was produced by a special \"model mention\" in a document that is open in the editor.
        A model mention is indicated via a leading / on a line.
        The user's currently selected text is indicated via ->->selected text<-<- surrounding selected text.
        In this sentence, the word ->->example<-<- is selected.
        Respond to any selected model mention.

        Wrap your responses in > < as follows.
        / What do you think?
        > I think that's a great idea. <

        For lines that are likely to wrap, or multiline responses, start and end the > and < on their own lines.
        >
        I think that's a great idea
        <

        If the selected mention is not at the end of the document, briefly summarize the context.
        > Key ideas of generative programming:
        * Managing context
        * Managing length
        * Context distillation
        - Shrink a context's size without loss of meaning.
        * Fine-grained version control
        * Portals to other contexts
        * Distillation policies
        * Budgets
        <
    "#};

    let selections = editor.selections.all(cx);
    let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
        // Insert ->-> <-<- around selected text as described in the system prompt above.
        // NOTE(review): this offset walk assumes `selections` are non-overlapping and in
        // buffer order — TODO confirm that `selections.all` guarantees it.
        let snapshot = buffer.snapshot(cx);
        let mut user_message = String::new();
        let mut buffer_offset = 0;
        for selection in selections {
            user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
            user_message.push_str("->->");
            user_message.extend(snapshot.text_for_range(selection.start..selection.end));
            buffer_offset = selection.end;
            user_message.push_str("<-<-");
        }
        // Append whatever follows the last selection.
        if buffer_offset < snapshot.len() {
            user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
        }

        // Ensure the document ends with 4 trailing newlines.
        // `take(4)` caps the count so the subtraction below cannot underflow. The prompt was
        // already built from the old snapshot, so this padding is not sent to the model.
        let trailing_newline_count = snapshot
            .reversed_chars_at(snapshot.len())
            .take_while(|c| *c == '\n')
            .take(4);
        let suffix = "\n".repeat(4 - trailing_newline_count.count());
        buffer.edit([(snapshot.len()..snapshot.len(), suffix)], None, cx);

        let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing.
        // The buffer now ends with at least four newlines, so this position leaves exactly
        // two newlines after the inserted reply.
        let insertion_site = snapshot.anchor_after(snapshot.len() - 2);

        (user_message, insertion_site)
    });

    // `stream_completion` is an `async fn`: nothing is sent until the spawned task awaits it.
    let stream = stream_completion(
        api_key,
        cx.background_executor().clone(),
        OpenAIRequest {
            model: "gpt-4".to_string(),
            messages: vec![
                RequestMessage {
                    role: Role::System,
                    content: SYSTEM_MESSAGE.to_string(),
                },
                RequestMessage {
                    role: Role::User,
                    content: user_message,
                },
            ],
            // Has no effect: `stream_completion` overwrites this with `true`.
            stream: false,
        },
    );
    let buffer = editor.buffer().clone();
    Some(cx.spawn(|_, mut cx| async move {
        let mut messages = stream.await?;
        while let Some(message) = messages.next().await {
            // The first error item ends the task with that error.
            let mut message = message?;
            // Only the last choice of each event is used.
            if let Some(choice) = message.choices.pop() {
                buffer.update(&mut cx, |buffer, cx| {
                    // Deltas without `content` are skipped via `?` on the `Option`.
                    let text: Arc<str> = choice.delta.content?.into();
                    buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx);
                    Some(())
                });
            }
        }
        Ok(())
    }))
}
189
190async fn stream_completion(
191 api_key: String,
192 executor: Arc<Background>,
193 mut request: OpenAIRequest,
194) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
195 request.stream = true;
196
197 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
198
199 let json_data = serde_json::to_string(&request)?;
200 let mut response = Request::post("https://api.openai.com/v1/chat/completions")
201 .header("Content-Type", "application/json")
202 .header("Authorization", format!("Bearer {}", api_key))
203 .body(json_data)?
204 .send_async()
205 .await?;
206
207 let status = response.status();
208 if status == StatusCode::OK {
209 executor
210 .spawn(async move {
211 let mut lines = BufReader::new(response.body_mut()).lines();
212
213 fn parse_line(
214 line: Result<String, io::Error>,
215 ) -> Result<Option<OpenAIResponseStreamEvent>> {
216 if let Some(data) = line?.strip_prefix("data: ") {
217 let event = serde_json::from_str(&data)?;
218 Ok(Some(event))
219 } else {
220 Ok(None)
221 }
222 }
223
224 while let Some(line) = lines.next().await {
225 if let Some(event) = parse_line(line).transpose() {
226 tx.unbounded_send(event).log_err();
227 }
228 }
229
230 anyhow::Ok(())
231 })
232 .detach();
233
234 Ok(rx)
235 } else {
236 let mut body = String::new();
237 response.body_mut().read_to_string(&mut body).await?;
238
239 Err(anyhow!(
240 "Failed to connect to OpenAI API: {} {}",
241 response.status(),
242 body,
243 ))
244 }
245}
246
// NOTE(review): no tests yet. `Role` serialization and deserializing a sample
// `OpenAIResponseStreamEvent` payload would be cheap, network-free candidates.
#[cfg(test)]
mod tests {}