1use anyhow::{anyhow, Result};
2use assets::Assets;
3use collections::HashMap;
4use editor::Editor;
5use futures::AsyncBufReadExt;
6use futures::{io::BufReader, AsyncReadExt, Stream, StreamExt};
7use gpui::executor::Background;
8use gpui::{actions, AppContext, Task, ViewContext};
9use isahc::prelude::*;
10use isahc::{http::StatusCode, Request};
11use serde::{Deserialize, Serialize};
12use std::cell::RefCell;
13use std::fs;
14use std::rc::Rc;
15use std::{io, sync::Arc};
16use util::channel::{ReleaseChannel, RELEASE_CHANNEL};
17use util::{ResultExt, TryFutureExt};
18
// Declare the `Assist` action under the `ai` namespace so it can be bound in keymaps.
actions!(ai, [Assist]);
20
// Data types for chat completion requests

/// Request body for the OpenAI chat completions endpoint.
#[derive(Serialize)]
struct OpenAIRequest {
    /// Model identifier, e.g. "gpt-4".
    model: String,
    /// Conversation history, typically a system prompt followed by the user message.
    messages: Vec<RequestMessage>,
    /// When true, the API responds with server-sent events instead of a single body.
    /// `stream_completion` force-sets this to true regardless of the caller's value.
    stream: bool,
}
28
/// A single message sent to the API; both fields are required on requests.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct RequestMessage {
    role: Role,
    content: String,
}
34
/// A message (delta) received from the API. Both fields are optional because
/// streamed deltas carry only the pieces that changed — see how
/// `ChatChoiceDelta::delta.content` is unwrapped with `?` when applying edits.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
struct ResponseMessage {
    role: Option<Role>,
    content: Option<String>,
}
40
/// Chat participant role; serialized as "user" / "assistant" / "system"
/// to match the OpenAI wire format.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
enum Role {
    User,
    Assistant,
    System,
}
48
/// One server-sent event from a streaming chat completion
/// (the JSON payload of a `data: {...}` line).
#[derive(Deserialize, Debug)]
struct OpenAIResponseStreamEvent {
    pub id: Option<String>,
    pub object: String,
    pub created: u32,
    pub model: String,
    /// Incremental choices; this code consumes the content of the last one.
    pub choices: Vec<ChatChoiceDelta>,
    /// Only present on some events (typically absent for intermediate deltas).
    pub usage: Option<Usage>,
}
58
/// Token accounting attached to a stream event, when the API reports it.
#[derive(Deserialize, Debug)]
struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}
65
/// A single choice within a stream event; `delta` holds the incremental
/// message fragment and `finish_reason` is set on the final event of a choice.
#[derive(Deserialize, Debug)]
struct ChatChoiceDelta {
    pub index: u32,
    pub delta: ResponseMessage,
    pub finish_reason: Option<String>,
}
72
// NOTE(review): not referenced anywhere in this file — looks like the usage
// shape of the legacy (non-chat) completions API, duplicating `Usage` with
// wider integers. Candidate for removal if nothing else in the crate uses it.
#[derive(Deserialize, Debug)]
struct OpenAIUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
    total_tokens: u64,
}
79
// NOTE(review): also unreferenced in this file — appears to model a choice
// from the legacy completions API (plain `text` rather than a chat `delta`).
// Verify before deleting.
#[derive(Deserialize, Debug)]
struct OpenAIChoice {
    text: String,
    index: u32,
    logprobs: Option<serde_json::Value>,
    finish_reason: Option<String>,
}
87
88pub fn init(cx: &mut AppContext) {
89 if *RELEASE_CHANNEL == ReleaseChannel::Stable {
90 return;
91 }
92
93 let assistant = Rc::new(Assistant::default());
94 cx.add_action({
95 let assistant = assistant.clone();
96 move |editor: &mut Editor, _: &Assist, cx: &mut ViewContext<Editor>| {
97 assistant.assist(editor, cx).log_err();
98 }
99 });
100 cx.capture_action({
101 let assistant = assistant.clone();
102 move |_: &mut Editor, _: &editor::Cancel, cx: &mut ViewContext<Editor>| {
103 if !assistant.cancel_last_assist(cx.view_id()) {
104 cx.propagate_action();
105 }
106 }
107 });
108}
109
/// Monotonically increasing identifier for a single in-flight assist request.
type CompletionId = usize;
111
/// Shared assistant state behind a `RefCell`; shared across action closures
/// via `Rc`, so it is single-threaded by construction.
#[derive(Default)]
struct Assistant(RefCell<AssistantState>);
114
#[derive(Default)]
struct AssistantState {
    // Per-editor (keyed by view id) stack of running assists, newest last,
    // so cancel pops the most recent one.
    assist_stacks: HashMap<usize, Vec<(CompletionId, Task<Option<()>>)>>,
    // Source of fresh `CompletionId`s, bumped with `util::post_inc`.
    next_completion_id: CompletionId,
}
120
impl Assistant {
    /// Start an assist for `editor`: build a prompt from the buffer contents
    /// and current selections, spawn a task that streams the model's reply
    /// into the buffer near its end, and record the task on this editor's
    /// assist stack so it can be canceled.
    ///
    /// Returns an error only if the `OPENAI_API_KEY` environment variable is
    /// unset; all later failures are logged by the spawned task.
    fn assist(self: &Rc<Self>, editor: &mut Editor, cx: &mut ViewContext<Editor>) -> Result<()> {
        let api_key = std::env::var("OPENAI_API_KEY")?;

        let selections = editor.selections.all(cx);
        let (user_message, insertion_site) = editor.buffer().update(cx, |buffer, cx| {
            // Insert markers around selected text as described in the system prompt above.
            let snapshot = buffer.snapshot(cx);
            let mut user_message = String::new();
            let mut user_message_suffix = String::new();
            let mut buffer_offset = 0;
            for selection in selections {
                if !selection.is_empty() {
                    // Also repeat each non-empty selection verbatim after the
                    // full document, so the model sees the excerpts in isolation.
                    if user_message_suffix.is_empty() {
                        user_message_suffix.push_str("\n\n");
                    }
                    user_message_suffix.push_str("[Selected excerpt from above]\n");
                    user_message_suffix
                        .extend(snapshot.text_for_range(selection.start..selection.end));
                    user_message_suffix.push_str("\n\n");
                }

                // Copy the text between the previous selection and this one,
                // then the selection itself wrapped in marker tokens.
                user_message.extend(snapshot.text_for_range(buffer_offset..selection.start));
                user_message.push_str("[SELECTION_START]");
                user_message.extend(snapshot.text_for_range(selection.start..selection.end));
                buffer_offset = selection.end;
                user_message.push_str("[SELECTION_END]");
            }
            // Remainder of the document after the last selection.
            if buffer_offset < snapshot.len() {
                user_message.extend(snapshot.text_for_range(buffer_offset..snapshot.len()));
            }
            user_message.push_str(&user_message_suffix);

            // Ensure the document ends with 4 trailing newlines.
            let trailing_newline_count = snapshot
                .reversed_chars_at(snapshot.len())
                .take_while(|c| *c == '\n')
                .take(4);
            let buffer_suffix = "\n".repeat(4 - trailing_newline_count.count());
            buffer.edit([(snapshot.len()..snapshot.len(), buffer_suffix)], None, cx);

            let snapshot = buffer.snapshot(cx); // Take a new snapshot after editing.
            // Anchor two characters before the end so streamed text lands
            // inside the trailing newlines rather than after them.
            let insertion_site = snapshot.anchor_after(snapshot.len() - 2);

            (user_message, insertion_site)
        });

        let this = self.clone();
        let buffer = editor.buffer().clone();
        let executor = cx.background_executor().clone();
        let editor_id = cx.view_id();
        let assist_id = util::post_inc(&mut self.0.borrow_mut().next_completion_id);
        let assist_task = cx.spawn(|_, mut cx| {
            async move {
                // TODO: We should have a get_string method on assets. This is repateated elsewhere.
                let content = Assets::get("contexts/system.zmd").unwrap();
                let mut system_message = std::str::from_utf8(content.data.as_ref())
                    .unwrap()
                    .to_string();

                // Optionally append a user-supplied system prompt from disk.
                if let Ok(custom_system_message_path) =
                    std::env::var("ZED_ASSISTANT_SYSTEM_PROMPT_PATH")
                {
                    system_message.push_str(
                        "\n\nAlso consider the following user-defined system prompt:\n\n",
                    );
                    // TODO: Replace this with our file system trait object.
                    system_message.push_str(
                        &cx.background()
                            .spawn(async move { fs::read_to_string(custom_system_message_path) })
                            .await?,
                    );
                }

                let stream = stream_completion(
                    api_key,
                    executor,
                    OpenAIRequest {
                        model: "gpt-4".to_string(),
                        messages: vec![
                            RequestMessage {
                                role: Role::System,
                                content: system_message.to_string(),
                            },
                            RequestMessage {
                                role: Role::User,
                                content: user_message,
                            },
                        ],
                        // NOTE: `stream_completion` overwrites this to `true`.
                        stream: false,
                    },
                );

                // Apply each streamed delta to the buffer at the anchored
                // insertion site as it arrives.
                let mut messages = stream.await?;
                while let Some(message) = messages.next().await {
                    let mut message = message?;
                    if let Some(choice) = message.choices.pop() {
                        buffer.update(&mut cx, |buffer, cx| {
                            // Deltas without content (e.g. role-only) are skipped
                            // via the `?` inside this Option-returning closure.
                            let text: Arc<str> = choice.delta.content?.into();
                            buffer.edit([(insertion_site.clone()..insertion_site, text)], None, cx);
                            Some(())
                        });
                    }
                }

                // Stream finished: remove this assist from the editor's stack.
                // NOTE(review): `unwrap` assumes the entry still exists; it was
                // inserted below before this task could run — verify nothing
                // removes the map entry while the task is live.
                this.0
                    .borrow_mut()
                    .assist_stacks
                    .get_mut(&editor_id)
                    .unwrap()
                    .retain(|(id, _)| *id != assist_id);

                anyhow::Ok(())
            }
            .log_err()
        });

        self.0
            .borrow_mut()
            .assist_stacks
            .entry(cx.view_id())
            .or_default()
            .push((assist_id, assist_task));

        Ok(())
    }

    /// Pop the most recent assist for `editor_id`, returning whether one was
    /// running. Dropping the popped `Task` presumably cancels the in-flight
    /// work (gpui task semantics — not visible from this file).
    fn cancel_last_assist(self: &Rc<Self>, editor_id: usize) -> bool {
        self.0
            .borrow_mut()
            .assist_stacks
            .get_mut(&editor_id)
            .and_then(|assists| assists.pop())
            .is_some()
    }
}
257
258async fn stream_completion(
259 api_key: String,
260 executor: Arc<Background>,
261 mut request: OpenAIRequest,
262) -> Result<impl Stream<Item = Result<OpenAIResponseStreamEvent>>> {
263 request.stream = true;
264
265 let (tx, rx) = futures::channel::mpsc::unbounded::<Result<OpenAIResponseStreamEvent>>();
266
267 let json_data = serde_json::to_string(&request)?;
268 let mut response = Request::post("https://api.openai.com/v1/chat/completions")
269 .header("Content-Type", "application/json")
270 .header("Authorization", format!("Bearer {}", api_key))
271 .body(json_data)?
272 .send_async()
273 .await?;
274
275 let status = response.status();
276 if status == StatusCode::OK {
277 executor
278 .spawn(async move {
279 let mut lines = BufReader::new(response.body_mut()).lines();
280
281 fn parse_line(
282 line: Result<String, io::Error>,
283 ) -> Result<Option<OpenAIResponseStreamEvent>> {
284 if let Some(data) = line?.strip_prefix("data: ") {
285 let event = serde_json::from_str(&data)?;
286 Ok(Some(event))
287 } else {
288 Ok(None)
289 }
290 }
291
292 while let Some(line) = lines.next().await {
293 if let Some(event) = parse_line(line).transpose() {
294 tx.unbounded_send(event).log_err();
295 }
296 }
297
298 anyhow::Ok(())
299 })
300 .detach();
301
302 Ok(rx)
303 } else {
304 let mut body = String::new();
305 response.body_mut().read_to_string(&mut body).await?;
306
307 Err(anyhow!(
308 "Failed to connect to OpenAI API: {} {}",
309 response.status(),
310 body,
311 ))
312 }
313}