terminal_inline_assistant.rs

  1use crate::context::attach_context_to_message;
  2use crate::context_store::ContextStore;
  3use crate::inline_prompt_editor::{
  4    CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
  5};
  6use crate::terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen};
  7use crate::thread_store::ThreadStore;
  8use anyhow::{Context as _, Result};
  9use client::telemetry::Telemetry;
 10use collections::{HashMap, VecDeque};
 11use editor::{MultiBuffer, actions::SelectAll};
 12use fs::Fs;
 13use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
 14use language::Buffer;
 15use language_model::{
 16    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
 17    report_assistant_event,
 18};
 19use prompt_store::PromptBuilder;
 20use std::sync::Arc;
 21use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
 22use terminal_view::TerminalView;
 23use ui::prelude::*;
 24use util::ResultExt;
 25use workspace::{Toast, Workspace, notifications::NotificationId};
 26
 27pub fn init(
 28    fs: Arc<dyn Fs>,
 29    prompt_builder: Arc<PromptBuilder>,
 30    telemetry: Arc<Telemetry>,
 31    cx: &mut App,
 32) {
 33    cx.set_global(TerminalInlineAssistant::new(fs, prompt_builder, telemetry));
 34}
 35
 36const DEFAULT_CONTEXT_LINES: usize = 50;
 37const PROMPT_HISTORY_MAX_LEN: usize = 20;
 38
/// Global coordinator for inline assists embedded in terminal views.
///
/// Owns every live assist keyed by id, plus the prompt history that is handed
/// to each newly created prompt editor.
pub struct TerminalInlineAssistant {
    /// Id that will be given to the next assist created by `assist`.
    next_assist_id: TerminalInlineAssistId,
    /// Assists that have been created and not yet removed by `finish_assist`.
    assists: HashMap<TerminalInlineAssistId, TerminalInlineAssist>,
    /// Previously submitted prompts, oldest at the front; deduplicated and
    /// capped at `PROMPT_HISTORY_MAX_LEN` by `start_assist`.
    prompt_history: VecDeque<String>,
    /// Telemetry client forwarded to each `TerminalCodegen`.
    telemetry: Option<Arc<Telemetry>>,
    /// Filesystem handle passed to each prompt editor.
    fs: Arc<dyn Fs>,
    /// Renders the terminal-assistant prompt sent to the model.
    prompt_builder: Arc<PromptBuilder>,
}

// Stored as a gpui global (installed by `init`) so event subscriptions can
// reach it via `update_global`.
impl Global for TerminalInlineAssistant {}
 49
 50impl TerminalInlineAssistant {
 51    pub fn new(
 52        fs: Arc<dyn Fs>,
 53        prompt_builder: Arc<PromptBuilder>,
 54        telemetry: Arc<Telemetry>,
 55    ) -> Self {
 56        Self {
 57            next_assist_id: TerminalInlineAssistId::default(),
 58            assists: HashMap::default(),
 59            prompt_history: VecDeque::default(),
 60            telemetry: Some(telemetry),
 61            fs,
 62            prompt_builder,
 63        }
 64    }
 65
 66    pub fn assist(
 67        &mut self,
 68        terminal_view: &Entity<TerminalView>,
 69        workspace: WeakEntity<Workspace>,
 70        thread_store: Option<WeakEntity<ThreadStore>>,
 71        window: &mut Window,
 72        cx: &mut App,
 73    ) {
 74        let terminal = terminal_view.read(cx).terminal().clone();
 75        let assist_id = self.next_assist_id.post_inc();
 76        let prompt_buffer =
 77            cx.new(|cx| MultiBuffer::singleton(cx.new(|cx| Buffer::local(String::new(), cx)), cx));
 78        let context_store = cx.new(|_cx| ContextStore::new(workspace.clone()));
 79        let codegen = cx.new(|_| TerminalCodegen::new(terminal, self.telemetry.clone()));
 80
 81        let prompt_editor = cx.new(|cx| {
 82            PromptEditor::new_terminal(
 83                assist_id,
 84                self.prompt_history.clone(),
 85                prompt_buffer.clone(),
 86                codegen,
 87                self.fs.clone(),
 88                context_store.clone(),
 89                workspace.clone(),
 90                thread_store.clone(),
 91                window,
 92                cx,
 93            )
 94        });
 95        let prompt_editor_render = prompt_editor.clone();
 96        let block = terminal_view::BlockProperties {
 97            height: 2,
 98            render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
 99        };
100        terminal_view.update(cx, |terminal_view, cx| {
101            terminal_view.set_block_below_cursor(block, window, cx);
102        });
103
104        let terminal_assistant = TerminalInlineAssist::new(
105            assist_id,
106            terminal_view,
107            prompt_editor,
108            workspace.clone(),
109            context_store,
110            window,
111            cx,
112        );
113
114        self.assists.insert(assist_id, terminal_assistant);
115
116        self.focus_assist(assist_id, window, cx);
117    }
118
119    fn focus_assist(
120        &mut self,
121        assist_id: TerminalInlineAssistId,
122        window: &mut Window,
123        cx: &mut App,
124    ) {
125        let assist = &self.assists[&assist_id];
126        if let Some(prompt_editor) = assist.prompt_editor.as_ref() {
127            prompt_editor.update(cx, |this, cx| {
128                this.editor.update(cx, |editor, cx| {
129                    window.focus(&editor.focus_handle(cx));
130                    editor.select_all(&SelectAll, window, cx);
131                });
132            });
133        }
134    }
135
136    fn handle_prompt_editor_event(
137        &mut self,
138        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
139        event: &PromptEditorEvent,
140        window: &mut Window,
141        cx: &mut App,
142    ) {
143        let assist_id = prompt_editor.read(cx).id();
144        match event {
145            PromptEditorEvent::StartRequested => {
146                self.start_assist(assist_id, cx);
147            }
148            PromptEditorEvent::StopRequested => {
149                self.stop_assist(assist_id, cx);
150            }
151            PromptEditorEvent::ConfirmRequested { execute } => {
152                self.finish_assist(assist_id, false, *execute, window, cx);
153            }
154            PromptEditorEvent::CancelRequested => {
155                self.finish_assist(assist_id, true, false, window, cx);
156            }
157            PromptEditorEvent::DismissRequested => {
158                self.dismiss_assist(assist_id, window, cx);
159            }
160            PromptEditorEvent::Resized { height_in_lines } => {
161                self.insert_prompt_editor_into_terminal(assist_id, *height_in_lines, window, cx);
162            }
163        }
164    }
165
166    fn start_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
167        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
168            assist
169        } else {
170            return;
171        };
172
173        let Some(user_prompt) = assist
174            .prompt_editor
175            .as_ref()
176            .map(|editor| editor.read(cx).prompt(cx))
177        else {
178            return;
179        };
180
181        self.prompt_history.retain(|prompt| *prompt != user_prompt);
182        self.prompt_history.push_back(user_prompt.clone());
183        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
184            self.prompt_history.pop_front();
185        }
186
187        assist
188            .terminal
189            .update(cx, |terminal, cx| {
190                terminal
191                    .terminal()
192                    .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
193            })
194            .log_err();
195
196        let codegen = assist.codegen.clone();
197        let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else {
198            return;
199        };
200
201        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
202    }
203
204    fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
205        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
206            assist
207        } else {
208            return;
209        };
210
211        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
212    }
213
214    fn request_for_inline_assist(
215        &self,
216        assist_id: TerminalInlineAssistId,
217        cx: &mut App,
218    ) -> Result<LanguageModelRequest> {
219        let assist = self.assists.get(&assist_id).context("invalid assist")?;
220
221        let shell = std::env::var("SHELL").ok();
222        let (latest_output, working_directory) = assist
223            .terminal
224            .update(cx, |terminal, cx| {
225                let terminal = terminal.entity().read(cx);
226                let latest_output = terminal.last_n_non_empty_lines(DEFAULT_CONTEXT_LINES);
227                let working_directory = terminal
228                    .working_directory()
229                    .map(|path| path.to_string_lossy().to_string());
230                (latest_output, working_directory)
231            })
232            .ok()
233            .unwrap_or_default();
234
235        let prompt = self.prompt_builder.generate_terminal_assistant_prompt(
236            &assist
237                .prompt_editor
238                .clone()
239                .context("invalid assist")?
240                .read(cx)
241                .prompt(cx),
242            shell.as_deref(),
243            working_directory.as_deref(),
244            &latest_output,
245        )?;
246
247        let mut request_message = LanguageModelRequestMessage {
248            role: Role::User,
249            content: vec![],
250            cache: false,
251        };
252
253        attach_context_to_message(
254            &mut request_message,
255            assist.context_store.read(cx).context().iter(),
256            cx,
257        );
258
259        request_message.content.push(prompt.into());
260
261        Ok(LanguageModelRequest {
262            messages: vec![request_message],
263            tools: Vec::new(),
264            stop: Vec::new(),
265            temperature: None,
266        })
267    }
268
269    fn finish_assist(
270        &mut self,
271        assist_id: TerminalInlineAssistId,
272        undo: bool,
273        execute: bool,
274        window: &mut Window,
275        cx: &mut App,
276    ) {
277        self.dismiss_assist(assist_id, window, cx);
278
279        if let Some(assist) = self.assists.remove(&assist_id) {
280            assist
281                .terminal
282                .update(cx, |this, cx| {
283                    this.clear_block_below_cursor(cx);
284                    this.focus_handle(cx).focus(window);
285                })
286                .log_err();
287
288            if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
289                let codegen = assist.codegen.read(cx);
290                let executor = cx.background_executor().clone();
291                report_assistant_event(
292                    AssistantEvent {
293                        conversation_id: None,
294                        kind: AssistantKind::InlineTerminal,
295                        message_id: codegen.message_id.clone(),
296                        phase: if undo {
297                            AssistantPhase::Rejected
298                        } else {
299                            AssistantPhase::Accepted
300                        },
301                        model: model.telemetry_id(),
302                        model_provider: model.provider_id().to_string(),
303                        response_latency: None,
304                        error_message: None,
305                        language_name: None,
306                    },
307                    codegen.telemetry.clone(),
308                    cx.http_client(),
309                    model.api_key(cx),
310                    &executor,
311                );
312            }
313
314            assist.codegen.update(cx, |codegen, cx| {
315                if undo {
316                    codegen.undo(cx);
317                } else if execute {
318                    codegen.complete(cx);
319                }
320            });
321        }
322    }
323
324    fn dismiss_assist(
325        &mut self,
326        assist_id: TerminalInlineAssistId,
327        window: &mut Window,
328        cx: &mut App,
329    ) -> bool {
330        let Some(assist) = self.assists.get_mut(&assist_id) else {
331            return false;
332        };
333        if assist.prompt_editor.is_none() {
334            return false;
335        }
336        assist.prompt_editor = None;
337        assist
338            .terminal
339            .update(cx, |this, cx| {
340                this.clear_block_below_cursor(cx);
341                this.focus_handle(cx).focus(window);
342            })
343            .is_ok()
344    }
345
346    fn insert_prompt_editor_into_terminal(
347        &mut self,
348        assist_id: TerminalInlineAssistId,
349        height: u8,
350        window: &mut Window,
351        cx: &mut App,
352    ) {
353        if let Some(assist) = self.assists.get_mut(&assist_id) {
354            if let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned() {
355                assist
356                    .terminal
357                    .update(cx, |terminal, cx| {
358                        terminal.clear_block_below_cursor(cx);
359                        let block = terminal_view::BlockProperties {
360                            height,
361                            render: Box::new(move |_| prompt_editor.clone().into_any_element()),
362                        };
363                        terminal.set_block_below_cursor(block, window, cx);
364                    })
365                    .log_err();
366            }
367        }
368    }
369}
370
/// State for a single inline assist attached to a terminal view.
struct TerminalInlineAssist {
    /// Terminal view whose below-cursor block hosts the prompt editor.
    terminal: WeakEntity<TerminalView>,
    /// The prompt editor; set to `None` once the assist is dismissed.
    prompt_editor: Option<Entity<PromptEditor<TerminalCodegen>>>,
    /// Code generator that runs the model request for this assist.
    codegen: Entity<TerminalCodegen>,
    /// Workspace used to show an error toast when codegen fails after dismissal.
    workspace: WeakEntity<Workspace>,
    /// Context entries attached to the outgoing request message.
    context_store: Entity<ContextStore>,
    /// Subscriptions to prompt-editor and codegen events; presumably held so
    /// they stay active for the assist's lifetime — confirm gpui drop semantics.
    _subscriptions: Vec<Subscription>,
}
379
380impl TerminalInlineAssist {
381    pub fn new(
382        assist_id: TerminalInlineAssistId,
383        terminal: &Entity<TerminalView>,
384        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
385        workspace: WeakEntity<Workspace>,
386        context_store: Entity<ContextStore>,
387        window: &mut Window,
388        cx: &mut App,
389    ) -> Self {
390        let codegen = prompt_editor.read(cx).codegen().clone();
391        Self {
392            terminal: terminal.downgrade(),
393            prompt_editor: Some(prompt_editor.clone()),
394            codegen: codegen.clone(),
395            workspace: workspace.clone(),
396            context_store,
397            _subscriptions: vec![
398                window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
399                    TerminalInlineAssistant::update_global(cx, |this, cx| {
400                        this.handle_prompt_editor_event(prompt_editor, event, window, cx)
401                    })
402                }),
403                window.subscribe(&codegen, cx, move |codegen, event, window, cx| {
404                    TerminalInlineAssistant::update_global(cx, |this, cx| match event {
405                        CodegenEvent::Finished => {
406                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
407                                assist
408                            } else {
409                                return;
410                            };
411
412                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
413                                if assist.prompt_editor.is_none() {
414                                    if let Some(workspace) = assist.workspace.upgrade() {
415                                        let error =
416                                            format!("Terminal inline assistant error: {}", error);
417                                        workspace.update(cx, |workspace, cx| {
418                                            struct InlineAssistantError;
419
420                                            let id =
421                                                NotificationId::composite::<InlineAssistantError>(
422                                                    assist_id.0,
423                                                );
424
425                                            workspace.show_toast(Toast::new(id, error), cx);
426                                        })
427                                    }
428                                }
429                            }
430
431                            if assist.prompt_editor.is_none() {
432                                this.finish_assist(assist_id, false, false, window, cx);
433                            }
434                        }
435                    })
436                }),
437            ],
438        }
439    }
440}