// terminal_inline_assistant.rs

  1use crate::context::attach_context_to_message;
  2use crate::context_store::ContextStore;
  3use crate::inline_prompt_editor::{
  4    CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
  5};
  6use crate::terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen};
  7use crate::thread_store::ThreadStore;
  8use anyhow::{Context as _, Result};
  9use client::telemetry::Telemetry;
 10use collections::{HashMap, VecDeque};
 11use editor::{MultiBuffer, actions::SelectAll};
 12use fs::Fs;
 13use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
 14use language::Buffer;
 15use language_model::{
 16    ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
 17    Role, report_assistant_event,
 18};
 19use prompt_store::PromptBuilder;
 20use std::sync::Arc;
 21use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
 22use terminal_view::TerminalView;
 23use ui::prelude::*;
 24use util::ResultExt;
 25use workspace::{Toast, Workspace, notifications::NotificationId};
 26
 27pub fn init(
 28    fs: Arc<dyn Fs>,
 29    prompt_builder: Arc<PromptBuilder>,
 30    telemetry: Arc<Telemetry>,
 31    cx: &mut App,
 32) {
 33    cx.set_global(TerminalInlineAssistant::new(fs, prompt_builder, telemetry));
 34}
 35
/// Number of most recent non-empty terminal lines included as context in a model request.
const DEFAULT_CONTEXT_LINES: usize = 50;
/// Maximum number of prompts kept in the shared prompt history.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
 38
/// Global coordinator for inline assists opened inside terminal views.
///
/// Tracks every active assist by id and keeps a prompt history shared across assists.
pub struct TerminalInlineAssistant {
    /// Id handed out (then incremented) by [`TerminalInlineAssistant::assist`].
    next_assist_id: TerminalInlineAssistId,
    /// Active assists, keyed by id; entries are removed when an assist is finished.
    assists: HashMap<TerminalInlineAssistId, TerminalInlineAssist>,
    /// Submitted prompts, oldest first, without duplicates; capped at `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    /// Telemetry client passed to each `TerminalCodegen`.
    telemetry: Option<Arc<Telemetry>>,
    /// Filesystem handle passed to each prompt editor.
    fs: Arc<dyn Fs>,
    /// Builds the terminal-assistant prompt text sent to the language model.
    prompt_builder: Arc<PromptBuilder>,
}

// Registered with `cx.set_global` in `init` and accessed via `update_global`.
impl Global for TerminalInlineAssistant {}
 49
 50impl TerminalInlineAssistant {
 51    pub fn new(
 52        fs: Arc<dyn Fs>,
 53        prompt_builder: Arc<PromptBuilder>,
 54        telemetry: Arc<Telemetry>,
 55    ) -> Self {
 56        Self {
 57            next_assist_id: TerminalInlineAssistId::default(),
 58            assists: HashMap::default(),
 59            prompt_history: VecDeque::default(),
 60            telemetry: Some(telemetry),
 61            fs,
 62            prompt_builder,
 63        }
 64    }
 65
 66    pub fn assist(
 67        &mut self,
 68        terminal_view: &Entity<TerminalView>,
 69        workspace: Entity<Workspace>,
 70        thread_store: Option<WeakEntity<ThreadStore>>,
 71        window: &mut Window,
 72        cx: &mut App,
 73    ) {
 74        let terminal = terminal_view.read(cx).terminal().clone();
 75        let assist_id = self.next_assist_id.post_inc();
 76        let prompt_buffer =
 77            cx.new(|cx| MultiBuffer::singleton(cx.new(|cx| Buffer::local(String::new(), cx)), cx));
 78        let project = workspace.read(cx).project().downgrade();
 79        let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
 80        let codegen = cx.new(|_| TerminalCodegen::new(terminal, self.telemetry.clone()));
 81
 82        let prompt_editor = cx.new(|cx| {
 83            PromptEditor::new_terminal(
 84                assist_id,
 85                self.prompt_history.clone(),
 86                prompt_buffer.clone(),
 87                codegen,
 88                self.fs.clone(),
 89                context_store.clone(),
 90                workspace.downgrade(),
 91                thread_store.clone(),
 92                window,
 93                cx,
 94            )
 95        });
 96        let prompt_editor_render = prompt_editor.clone();
 97        let block = terminal_view::BlockProperties {
 98            height: 2,
 99            render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
100        };
101        terminal_view.update(cx, |terminal_view, cx| {
102            terminal_view.set_block_below_cursor(block, window, cx);
103        });
104
105        let terminal_assistant = TerminalInlineAssist::new(
106            assist_id,
107            terminal_view,
108            prompt_editor,
109            workspace.downgrade(),
110            context_store,
111            window,
112            cx,
113        );
114
115        self.assists.insert(assist_id, terminal_assistant);
116
117        self.focus_assist(assist_id, window, cx);
118    }
119
120    fn focus_assist(
121        &mut self,
122        assist_id: TerminalInlineAssistId,
123        window: &mut Window,
124        cx: &mut App,
125    ) {
126        let assist = &self.assists[&assist_id];
127        if let Some(prompt_editor) = assist.prompt_editor.as_ref() {
128            prompt_editor.update(cx, |this, cx| {
129                this.editor.update(cx, |editor, cx| {
130                    window.focus(&editor.focus_handle(cx));
131                    editor.select_all(&SelectAll, window, cx);
132                });
133            });
134        }
135    }
136
137    fn handle_prompt_editor_event(
138        &mut self,
139        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
140        event: &PromptEditorEvent,
141        window: &mut Window,
142        cx: &mut App,
143    ) {
144        let assist_id = prompt_editor.read(cx).id();
145        match event {
146            PromptEditorEvent::StartRequested => {
147                self.start_assist(assist_id, cx);
148            }
149            PromptEditorEvent::StopRequested => {
150                self.stop_assist(assist_id, cx);
151            }
152            PromptEditorEvent::ConfirmRequested { execute } => {
153                self.finish_assist(assist_id, false, *execute, window, cx);
154            }
155            PromptEditorEvent::CancelRequested => {
156                self.finish_assist(assist_id, true, false, window, cx);
157            }
158            PromptEditorEvent::DismissRequested => {
159                self.dismiss_assist(assist_id, window, cx);
160            }
161            PromptEditorEvent::Resized { height_in_lines } => {
162                self.insert_prompt_editor_into_terminal(assist_id, *height_in_lines, window, cx);
163            }
164        }
165    }
166
167    fn start_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
168        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
169            assist
170        } else {
171            return;
172        };
173
174        let Some(user_prompt) = assist
175            .prompt_editor
176            .as_ref()
177            .map(|editor| editor.read(cx).prompt(cx))
178        else {
179            return;
180        };
181
182        self.prompt_history.retain(|prompt| *prompt != user_prompt);
183        self.prompt_history.push_back(user_prompt.clone());
184        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
185            self.prompt_history.pop_front();
186        }
187
188        assist
189            .terminal
190            .update(cx, |terminal, cx| {
191                terminal
192                    .terminal()
193                    .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
194            })
195            .log_err();
196
197        let codegen = assist.codegen.clone();
198        let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else {
199            return;
200        };
201
202        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
203    }
204
205    fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
206        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
207            assist
208        } else {
209            return;
210        };
211
212        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
213    }
214
215    fn request_for_inline_assist(
216        &self,
217        assist_id: TerminalInlineAssistId,
218        cx: &mut App,
219    ) -> Result<LanguageModelRequest> {
220        let assist = self.assists.get(&assist_id).context("invalid assist")?;
221
222        let shell = std::env::var("SHELL").ok();
223        let (latest_output, working_directory) = assist
224            .terminal
225            .update(cx, |terminal, cx| {
226                let terminal = terminal.entity().read(cx);
227                let latest_output = terminal.last_n_non_empty_lines(DEFAULT_CONTEXT_LINES);
228                let working_directory = terminal
229                    .working_directory()
230                    .map(|path| path.to_string_lossy().to_string());
231                (latest_output, working_directory)
232            })
233            .ok()
234            .unwrap_or_default();
235
236        let prompt = self.prompt_builder.generate_terminal_assistant_prompt(
237            &assist
238                .prompt_editor
239                .clone()
240                .context("invalid assist")?
241                .read(cx)
242                .prompt(cx),
243            shell.as_deref(),
244            working_directory.as_deref(),
245            &latest_output,
246        )?;
247
248        let mut request_message = LanguageModelRequestMessage {
249            role: Role::User,
250            content: vec![],
251            cache: false,
252        };
253
254        attach_context_to_message(
255            &mut request_message,
256            assist.context_store.read(cx).context().iter(),
257            cx,
258        );
259
260        request_message.content.push(prompt.into());
261
262        Ok(LanguageModelRequest {
263            messages: vec![request_message],
264            tools: Vec::new(),
265            stop: Vec::new(),
266            temperature: None,
267        })
268    }
269
270    fn finish_assist(
271        &mut self,
272        assist_id: TerminalInlineAssistId,
273        undo: bool,
274        execute: bool,
275        window: &mut Window,
276        cx: &mut App,
277    ) {
278        self.dismiss_assist(assist_id, window, cx);
279
280        if let Some(assist) = self.assists.remove(&assist_id) {
281            assist
282                .terminal
283                .update(cx, |this, cx| {
284                    this.clear_block_below_cursor(cx);
285                    this.focus_handle(cx).focus(window);
286                })
287                .log_err();
288
289            if let Some(ConfiguredModel { model, .. }) =
290                LanguageModelRegistry::read_global(cx).inline_assistant_model()
291            {
292                let codegen = assist.codegen.read(cx);
293                let executor = cx.background_executor().clone();
294                report_assistant_event(
295                    AssistantEventData {
296                        conversation_id: None,
297                        kind: AssistantKind::InlineTerminal,
298                        message_id: codegen.message_id.clone(),
299                        phase: if undo {
300                            AssistantPhase::Rejected
301                        } else {
302                            AssistantPhase::Accepted
303                        },
304                        model: model.telemetry_id(),
305                        model_provider: model.provider_id().to_string(),
306                        response_latency: None,
307                        error_message: None,
308                        language_name: None,
309                    },
310                    codegen.telemetry.clone(),
311                    cx.http_client(),
312                    model.api_key(cx),
313                    &executor,
314                );
315            }
316
317            assist.codegen.update(cx, |codegen, cx| {
318                if undo {
319                    codegen.undo(cx);
320                } else if execute {
321                    codegen.complete(cx);
322                }
323            });
324        }
325    }
326
327    fn dismiss_assist(
328        &mut self,
329        assist_id: TerminalInlineAssistId,
330        window: &mut Window,
331        cx: &mut App,
332    ) -> bool {
333        let Some(assist) = self.assists.get_mut(&assist_id) else {
334            return false;
335        };
336        if assist.prompt_editor.is_none() {
337            return false;
338        }
339        assist.prompt_editor = None;
340        assist
341            .terminal
342            .update(cx, |this, cx| {
343                this.clear_block_below_cursor(cx);
344                this.focus_handle(cx).focus(window);
345            })
346            .is_ok()
347    }
348
349    fn insert_prompt_editor_into_terminal(
350        &mut self,
351        assist_id: TerminalInlineAssistId,
352        height: u8,
353        window: &mut Window,
354        cx: &mut App,
355    ) {
356        if let Some(assist) = self.assists.get_mut(&assist_id) {
357            if let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned() {
358                assist
359                    .terminal
360                    .update(cx, |terminal, cx| {
361                        terminal.clear_block_below_cursor(cx);
362                        let block = terminal_view::BlockProperties {
363                            height,
364                            render: Box::new(move |_| prompt_editor.clone().into_any_element()),
365                        };
366                        terminal.set_block_below_cursor(block, window, cx);
367                    })
368                    .log_err();
369            }
370        }
371    }
372}
373
/// State for a single inline assist running in a terminal view.
struct TerminalInlineAssist {
    /// Terminal view hosting the assist. Held weakly; updates through it are fallible and
    /// their failures are logged or ignored by callers.
    terminal: WeakEntity<TerminalView>,
    /// Prompt editor rendered below the terminal cursor; `None` once the assist is dismissed.
    prompt_editor: Option<Entity<PromptEditor<TerminalCodegen>>>,
    /// Code generator shared with the prompt editor.
    codegen: Entity<TerminalCodegen>,
    /// Workspace used to show error toasts when generation fails after dismissal.
    workspace: WeakEntity<Workspace>,
    /// Context attached to the request message sent to the model.
    context_store: Entity<ContextStore>,
    /// Prompt-editor and codegen event subscriptions created in `new`, held for the
    /// lifetime of the assist.
    _subscriptions: Vec<Subscription>,
}
382
383impl TerminalInlineAssist {
384    pub fn new(
385        assist_id: TerminalInlineAssistId,
386        terminal: &Entity<TerminalView>,
387        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
388        workspace: WeakEntity<Workspace>,
389        context_store: Entity<ContextStore>,
390        window: &mut Window,
391        cx: &mut App,
392    ) -> Self {
393        let codegen = prompt_editor.read(cx).codegen().clone();
394        Self {
395            terminal: terminal.downgrade(),
396            prompt_editor: Some(prompt_editor.clone()),
397            codegen: codegen.clone(),
398            workspace: workspace.clone(),
399            context_store,
400            _subscriptions: vec![
401                window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
402                    TerminalInlineAssistant::update_global(cx, |this, cx| {
403                        this.handle_prompt_editor_event(prompt_editor, event, window, cx)
404                    })
405                }),
406                window.subscribe(&codegen, cx, move |codegen, event, window, cx| {
407                    TerminalInlineAssistant::update_global(cx, |this, cx| match event {
408                        CodegenEvent::Finished => {
409                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
410                                assist
411                            } else {
412                                return;
413                            };
414
415                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
416                                if assist.prompt_editor.is_none() {
417                                    if let Some(workspace) = assist.workspace.upgrade() {
418                                        let error =
419                                            format!("Terminal inline assistant error: {}", error);
420                                        workspace.update(cx, |workspace, cx| {
421                                            struct InlineAssistantError;
422
423                                            let id =
424                                                NotificationId::composite::<InlineAssistantError>(
425                                                    assist_id.0,
426                                                );
427
428                                            workspace.show_toast(Toast::new(id, error), cx);
429                                        })
430                                    }
431                                }
432                            }
433
434                            if assist.prompt_editor.is_none() {
435                                this.finish_assist(assist_id, false, false, window, cx);
436                            }
437                        }
438                    })
439                }),
440            ],
441        }
442    }
443}