1use crate::{
2 stream_completion,
3 streaming_diff::{Hunk, StreamingDiff},
4 OpenAIRequest,
5};
6use anyhow::Result;
7use editor::{multi_buffer, Anchor, MultiBuffer, ToOffset, ToPoint};
8use futures::{
9 channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, SinkExt, Stream, StreamExt,
10};
11use gpui::{executor::Background, Entity, ModelContext, ModelHandle, Task};
12use language::{IndentSize, Point, Rope, TransactionId};
13use std::{cmp, future, ops::Range, sync::Arc};
14
15pub trait CompletionProvider {
16 fn complete(
17 &self,
18 prompt: OpenAIRequest,
19 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>>;
20}
21
22pub struct OpenAICompletionProvider {
23 api_key: String,
24 executor: Arc<Background>,
25}
26
27impl OpenAICompletionProvider {
28 pub fn new(api_key: String, executor: Arc<Background>) -> Self {
29 Self { api_key, executor }
30 }
31}
32
33impl CompletionProvider for OpenAICompletionProvider {
34 fn complete(
35 &self,
36 prompt: OpenAIRequest,
37 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
38 let request = stream_completion(self.api_key.clone(), self.executor.clone(), prompt);
39 async move {
40 let response = request.await?;
41 let stream = response
42 .filter_map(|response| async move {
43 match response {
44 Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
45 Err(error) => Some(Err(error)),
46 }
47 })
48 .boxed();
49 Ok(stream)
50 }
51 .boxed()
52 }
53}
54
/// Events emitted by [`Codegen`].
pub enum Event {
    /// A generation started by `Codegen::start` ran to its end, successfully or
    /// with an error.
    Finished,
    /// The buffer transaction holding the generated edits was undone.
    Undone,
}
59
60pub struct Codegen {
61 provider: Arc<dyn CompletionProvider>,
62 buffer: ModelHandle<MultiBuffer>,
63 range: Range<Anchor>,
64 last_equal_ranges: Vec<Range<Anchor>>,
65 transaction_id: Option<TransactionId>,
66 error: Option<anyhow::Error>,
67 generation: Task<()>,
68 idle: bool,
69 _subscription: gpui::Subscription,
70}
71
72impl Entity for Codegen {
73 type Event = Event;
74}
75
76impl Codegen {
77 pub fn new(
78 buffer: ModelHandle<MultiBuffer>,
79 range: Range<Anchor>,
80 provider: Arc<dyn CompletionProvider>,
81 cx: &mut ModelContext<Self>,
82 ) -> Self {
83 Self {
84 provider,
85 buffer: buffer.clone(),
86 range,
87 last_equal_ranges: Default::default(),
88 transaction_id: Default::default(),
89 error: Default::default(),
90 idle: true,
91 generation: Task::ready(()),
92 _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
93 }
94 }
95
96 fn handle_buffer_event(
97 &mut self,
98 _buffer: ModelHandle<MultiBuffer>,
99 event: &multi_buffer::Event,
100 cx: &mut ModelContext<Self>,
101 ) {
102 if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
103 if self.transaction_id == Some(*transaction_id) {
104 self.transaction_id = None;
105 self.generation = Task::ready(());
106 cx.emit(Event::Undone);
107 }
108 }
109 }
110
111 pub fn range(&self) -> Range<Anchor> {
112 self.range.clone()
113 }
114
115 pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
116 &self.last_equal_ranges
117 }
118
119 pub fn idle(&self) -> bool {
120 self.idle
121 }
122
123 pub fn error(&self) -> Option<&anyhow::Error> {
124 self.error.as_ref()
125 }
126
127 pub fn start(&mut self, prompt: OpenAIRequest, cx: &mut ModelContext<Self>) {
128 let range = self.range.clone();
129 let snapshot = self.buffer.read(cx).snapshot(cx);
130 let selected_text = snapshot
131 .text_for_range(range.start..range.end)
132 .collect::<Rope>();
133
134 let selection_start = range.start.to_point(&snapshot);
135 let selection_end = range.end.to_point(&snapshot);
136
137 let mut base_indent: Option<IndentSize> = None;
138 let mut start_row = selection_start.row;
139 if snapshot.is_line_blank(start_row) {
140 if let Some(prev_non_blank_row) = snapshot.prev_non_blank_row(start_row) {
141 start_row = prev_non_blank_row;
142 }
143 }
144 for row in start_row..=selection_end.row {
145 if snapshot.is_line_blank(row) {
146 continue;
147 }
148
149 let line_indent = snapshot.indent_size_for_line(row);
150 if let Some(base_indent) = base_indent.as_mut() {
151 if line_indent.len < base_indent.len {
152 *base_indent = line_indent;
153 }
154 } else {
155 base_indent = Some(line_indent);
156 }
157 }
158
159 let mut normalized_selected_text = selected_text.clone();
160 if let Some(base_indent) = base_indent {
161 for row in selection_start.row..=selection_end.row {
162 let selection_row = row - selection_start.row;
163 let line_start =
164 normalized_selected_text.point_to_offset(Point::new(selection_row, 0));
165 let indent_len = if row == selection_start.row {
166 base_indent.len.saturating_sub(selection_start.column)
167 } else {
168 let line_len = normalized_selected_text.line_len(selection_row);
169 cmp::min(line_len, base_indent.len)
170 };
171 let indent_end = cmp::min(
172 line_start + indent_len as usize,
173 normalized_selected_text.len(),
174 );
175 normalized_selected_text.replace(line_start..indent_end, "");
176 }
177 }
178
179 let response = self.provider.complete(prompt);
180 self.generation = cx.spawn_weak(|this, mut cx| {
181 async move {
182 let generate = async {
183 let mut edit_start = range.start.to_offset(&snapshot);
184
185 let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
186 let diff = cx.background().spawn(async move {
187 let chunks = strip_markdown_codeblock(response.await?);
188 futures::pin_mut!(chunks);
189 let mut diff = StreamingDiff::new(selected_text.to_string());
190
191 let mut indent_len;
192 let indent_text;
193 if let Some(base_indent) = base_indent {
194 indent_len = base_indent.len;
195 indent_text = match base_indent.kind {
196 language::IndentKind::Space => " ",
197 language::IndentKind::Tab => "\t",
198 };
199 } else {
200 indent_len = 0;
201 indent_text = "";
202 };
203
204 let mut first_line_len = 0;
205 let mut first_line_non_whitespace_char_ix = None;
206 let mut first_line = true;
207 let mut new_text = String::new();
208
209 while let Some(chunk) = chunks.next().await {
210 let chunk = chunk?;
211
212 let mut lines = chunk.split('\n');
213 if let Some(mut line) = lines.next() {
214 if first_line {
215 if first_line_non_whitespace_char_ix.is_none() {
216 if let Some(mut char_ix) =
217 line.find(|ch: char| !ch.is_whitespace())
218 {
219 line = &line[char_ix..];
220 char_ix += first_line_len;
221 first_line_non_whitespace_char_ix = Some(char_ix);
222 let first_line_indent = char_ix
223 .saturating_sub(selection_start.column as usize)
224 as usize;
225 new_text
226 .push_str(&indent_text.repeat(first_line_indent));
227 indent_len = indent_len.saturating_sub(char_ix as u32);
228 }
229 }
230 first_line_len += line.len();
231 }
232
233 if first_line_non_whitespace_char_ix.is_some() {
234 new_text.push_str(line);
235 }
236 }
237
238 for line in lines {
239 first_line = false;
240 new_text.push('\n');
241 if !line.is_empty() {
242 new_text.push_str(&indent_text.repeat(indent_len as usize));
243 }
244 new_text.push_str(line);
245 }
246
247 let hunks = diff.push_new(&new_text);
248 hunks_tx.send(hunks).await?;
249 new_text.clear();
250 }
251 hunks_tx.send(diff.finish()).await?;
252
253 anyhow::Ok(())
254 });
255
256 while let Some(hunks) = hunks_rx.next().await {
257 let this = if let Some(this) = this.upgrade(&cx) {
258 this
259 } else {
260 break;
261 };
262
263 this.update(&mut cx, |this, cx| {
264 this.last_equal_ranges.clear();
265
266 let transaction = this.buffer.update(cx, |buffer, cx| {
267 // Avoid grouping assistant edits with user edits.
268 buffer.finalize_last_transaction(cx);
269
270 buffer.start_transaction(cx);
271 buffer.edit(
272 hunks.into_iter().filter_map(|hunk| match hunk {
273 Hunk::Insert { text } => {
274 let edit_start = snapshot.anchor_after(edit_start);
275 Some((edit_start..edit_start, text))
276 }
277 Hunk::Remove { len } => {
278 let edit_end = edit_start + len;
279 let edit_range = snapshot.anchor_after(edit_start)
280 ..snapshot.anchor_before(edit_end);
281 edit_start = edit_end;
282 Some((edit_range, String::new()))
283 }
284 Hunk::Keep { len } => {
285 let edit_end = edit_start + len;
286 let edit_range = snapshot.anchor_after(edit_start)
287 ..snapshot.anchor_before(edit_end);
288 edit_start += len;
289 this.last_equal_ranges.push(edit_range);
290 None
291 }
292 }),
293 None,
294 cx,
295 );
296
297 buffer.end_transaction(cx)
298 });
299
300 if let Some(transaction) = transaction {
301 if let Some(first_transaction) = this.transaction_id {
302 // Group all assistant edits into the first transaction.
303 this.buffer.update(cx, |buffer, cx| {
304 buffer.merge_transactions(
305 transaction,
306 first_transaction,
307 cx,
308 )
309 });
310 } else {
311 this.transaction_id = Some(transaction);
312 this.buffer.update(cx, |buffer, cx| {
313 buffer.finalize_last_transaction(cx)
314 });
315 }
316 }
317
318 cx.notify();
319 });
320 }
321
322 diff.await?;
323 anyhow::Ok(())
324 };
325
326 let result = generate.await;
327 if let Some(this) = this.upgrade(&cx) {
328 this.update(&mut cx, |this, cx| {
329 this.last_equal_ranges.clear();
330 this.idle = true;
331 if let Err(error) = result {
332 this.error = Some(error);
333 }
334 cx.emit(Event::Finished);
335 cx.notify();
336 });
337 }
338 }
339 });
340 self.error.take();
341 self.idle = false;
342 cx.notify();
343 }
344
345 pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
346 if let Some(transaction_id) = self.transaction_id {
347 self.buffer
348 .update(cx, |buffer, cx| buffer.undo_transaction(transaction_id, cx));
349 }
350 }
351}
352
353fn strip_markdown_codeblock(
354 stream: impl Stream<Item = Result<String>>,
355) -> impl Stream<Item = Result<String>> {
356 let mut first_line = true;
357 let mut buffer = String::new();
358 let mut starts_with_fenced_code_block = false;
359 stream.filter_map(move |chunk| {
360 let chunk = match chunk {
361 Ok(chunk) => chunk,
362 Err(err) => return future::ready(Some(Err(err))),
363 };
364 buffer.push_str(&chunk);
365
366 if first_line {
367 if buffer == "" || buffer == "`" || buffer == "``" {
368 return future::ready(None);
369 } else if buffer.starts_with("```") {
370 starts_with_fenced_code_block = true;
371 if let Some(newline_ix) = buffer.find('\n') {
372 buffer.replace_range(..newline_ix + 1, "");
373 first_line = false;
374 } else {
375 return future::ready(None);
376 }
377 }
378 }
379
380 let text = if starts_with_fenced_code_block {
381 buffer
382 .strip_suffix("\n```\n")
383 .or_else(|| buffer.strip_suffix("\n```"))
384 .or_else(|| buffer.strip_suffix("\n``"))
385 .or_else(|| buffer.strip_suffix("\n`"))
386 .or_else(|| buffer.strip_suffix('\n'))
387 .unwrap_or(&buffer)
388 } else {
389 &buffer
390 };
391
392 if text.contains('\n') {
393 first_line = false;
394 }
395
396 let remainder = buffer.split_off(text.len());
397 let result = if buffer.is_empty() {
398 None
399 } else {
400 Some(Ok(buffer.clone()))
401 };
402 buffer = remainder;
403 future::ready(result)
404 })
405}
406
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;

    /// Yields `text` as a stream of successful chunks of at most `size`
    /// characters each.
    fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
        let chars = text.chars().collect::<Vec<_>>();
        let pieces = chars
            .chunks(size)
            .map(|piece| Ok(piece.iter().collect::<String>()))
            .collect::<Vec<Result<String>>>();
        stream::iter(pieces)
    }

    #[gpui::test]
    async fn test_strip_markdown_codeblock() {
        // (input, expected output), each streamed two characters at a time.
        let cases = [
            ("Lorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor"),
            ("```\nLorem ipsum dolor\n```\n", "Lorem ipsum dolor"),
            (
                "```html\n```js\nLorem ipsum dolor\n```\n```",
                "```js\nLorem ipsum dolor\n```",
            ),
            ("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```"),
        ];

        for &(input, expected) in cases.iter() {
            let stripped = strip_markdown_codeblock(chunks(input, 2))
                .map(|chunk| chunk.unwrap())
                .collect::<String>()
                .await;
            assert_eq!(stripped, expected);
        }
    }
}