terminal_inline_assistant.rs

  1use crate::context::attach_context_to_message;
  2use crate::context_store::ContextStore;
  3use crate::inline_prompt_editor::{
  4    CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
  5};
  6use crate::terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen};
  7use crate::thread_store::ThreadStore;
  8use anyhow::{Context as _, Result};
  9use client::telemetry::Telemetry;
 10use collections::{HashMap, VecDeque};
 11use editor::{MultiBuffer, actions::SelectAll};
 12use fs::Fs;
 13use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
 14use language::Buffer;
 15use language_model::{
 16    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
 17    report_assistant_event,
 18};
 19use prompt_store::PromptBuilder;
 20use std::sync::Arc;
 21use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
 22use terminal_view::TerminalView;
 23use ui::prelude::*;
 24use util::ResultExt;
 25use workspace::{Toast, Workspace, notifications::NotificationId};
 26
 27pub fn init(
 28    fs: Arc<dyn Fs>,
 29    prompt_builder: Arc<PromptBuilder>,
 30    telemetry: Arc<Telemetry>,
 31    cx: &mut App,
 32) {
 33    cx.set_global(TerminalInlineAssistant::new(fs, prompt_builder, telemetry));
 34}
 35
/// Line count passed to `last_n_non_empty_lines` when gathering terminal output for the
/// model prompt.
const DEFAULT_CONTEXT_LINES: usize = 50;
/// Maximum number of entries kept in the assistant's prompt history; the oldest entry is
/// dropped once this is exceeded.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
 38
/// Global coordinator for inline assists shown inside terminal views.
///
/// A single instance is installed as a GPUI global by `init`; it owns every in-flight
/// assist and the prompt history that is handed to each new prompt editor.
pub struct TerminalInlineAssistant {
    /// Id handed to the next assist; advanced with `post_inc` on every `assist` call.
    next_assist_id: TerminalInlineAssistId,
    /// Assists that have been started and not yet finished, keyed by id.
    assists: HashMap<TerminalInlineAssistId, TerminalInlineAssist>,
    /// Previously submitted prompts, oldest first, capped at `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    /// Telemetry handle cloned into each `TerminalCodegen`.
    telemetry: Option<Arc<Telemetry>>,
    /// File system handle passed to each prompt editor.
    fs: Arc<dyn Fs>,
    /// Generates the terminal assistant prompt sent to the model.
    prompt_builder: Arc<PromptBuilder>,
}
 47
// Marker impl that lets the assistant be stored with `cx.set_global` and reached through
// `TerminalInlineAssistant::update_global`.
impl Global for TerminalInlineAssistant {}
 49
 50impl TerminalInlineAssistant {
 51    pub fn new(
 52        fs: Arc<dyn Fs>,
 53        prompt_builder: Arc<PromptBuilder>,
 54        telemetry: Arc<Telemetry>,
 55    ) -> Self {
 56        Self {
 57            next_assist_id: TerminalInlineAssistId::default(),
 58            assists: HashMap::default(),
 59            prompt_history: VecDeque::default(),
 60            telemetry: Some(telemetry),
 61            fs,
 62            prompt_builder,
 63        }
 64    }
 65
 66    pub fn assist(
 67        &mut self,
 68        terminal_view: &Entity<TerminalView>,
 69        workspace: WeakEntity<Workspace>,
 70        thread_store: Option<WeakEntity<ThreadStore>>,
 71        window: &mut Window,
 72        cx: &mut App,
 73    ) {
 74        let terminal = terminal_view.read(cx).terminal().clone();
 75        let assist_id = self.next_assist_id.post_inc();
 76        let prompt_buffer =
 77            cx.new(|cx| MultiBuffer::singleton(cx.new(|cx| Buffer::local(String::new(), cx)), cx));
 78        let context_store = cx.new(|_cx| ContextStore::new(workspace.clone()));
 79        let codegen = cx.new(|_| TerminalCodegen::new(terminal, self.telemetry.clone()));
 80
 81        let prompt_editor = cx.new(|cx| {
 82            PromptEditor::new_terminal(
 83                assist_id,
 84                self.prompt_history.clone(),
 85                prompt_buffer.clone(),
 86                codegen,
 87                self.fs.clone(),
 88                context_store.clone(),
 89                workspace.clone(),
 90                thread_store.clone(),
 91                window,
 92                cx,
 93            )
 94        });
 95        let prompt_editor_render = prompt_editor.clone();
 96        let block = terminal_view::BlockProperties {
 97            height: 2,
 98            render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
 99        };
100        terminal_view.update(cx, |terminal_view, cx| {
101            terminal_view.set_block_below_cursor(block, window, cx);
102        });
103
104        let terminal_assistant = TerminalInlineAssist::new(
105            assist_id,
106            terminal_view,
107            prompt_editor,
108            workspace.clone(),
109            context_store,
110            window,
111            cx,
112        );
113
114        self.assists.insert(assist_id, terminal_assistant);
115
116        self.focus_assist(assist_id, window, cx);
117    }
118
119    fn focus_assist(
120        &mut self,
121        assist_id: TerminalInlineAssistId,
122        window: &mut Window,
123        cx: &mut App,
124    ) {
125        let assist = &self.assists[&assist_id];
126        if let Some(prompt_editor) = assist.prompt_editor.as_ref() {
127            prompt_editor.update(cx, |this, cx| {
128                this.editor.update(cx, |editor, cx| {
129                    window.focus(&editor.focus_handle(cx));
130                    editor.select_all(&SelectAll, window, cx);
131                });
132            });
133        }
134    }
135
136    fn handle_prompt_editor_event(
137        &mut self,
138        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
139        event: &PromptEditorEvent,
140        window: &mut Window,
141        cx: &mut App,
142    ) {
143        let assist_id = prompt_editor.read(cx).id();
144        match event {
145            PromptEditorEvent::StartRequested => {
146                self.start_assist(assist_id, cx);
147            }
148            PromptEditorEvent::StopRequested => {
149                self.stop_assist(assist_id, cx);
150            }
151            PromptEditorEvent::ConfirmRequested { execute } => {
152                self.finish_assist(assist_id, false, *execute, window, cx);
153            }
154            PromptEditorEvent::CancelRequested => {
155                self.finish_assist(assist_id, true, false, window, cx);
156            }
157            PromptEditorEvent::DismissRequested => {
158                self.dismiss_assist(assist_id, window, cx);
159            }
160            PromptEditorEvent::Resized { height_in_lines } => {
161                self.insert_prompt_editor_into_terminal(assist_id, *height_in_lines, window, cx);
162            }
163        }
164    }
165
166    fn start_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
167        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
168            assist
169        } else {
170            return;
171        };
172
173        let Some(user_prompt) = assist
174            .prompt_editor
175            .as_ref()
176            .map(|editor| editor.read(cx).prompt(cx))
177        else {
178            return;
179        };
180
181        self.prompt_history.retain(|prompt| *prompt != user_prompt);
182        self.prompt_history.push_back(user_prompt.clone());
183        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
184            self.prompt_history.pop_front();
185        }
186
187        assist
188            .terminal
189            .update(cx, |terminal, cx| {
190                terminal
191                    .terminal()
192                    .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
193            })
194            .log_err();
195
196        let codegen = assist.codegen.clone();
197        let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else {
198            return;
199        };
200
201        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
202    }
203
204    fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
205        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
206            assist
207        } else {
208            return;
209        };
210
211        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
212    }
213
214    fn request_for_inline_assist(
215        &self,
216        assist_id: TerminalInlineAssistId,
217        cx: &mut App,
218    ) -> Result<LanguageModelRequest> {
219        let assist = self.assists.get(&assist_id).context("invalid assist")?;
220
221        let shell = std::env::var("SHELL").ok();
222        let (latest_output, working_directory) = assist
223            .terminal
224            .update(cx, |terminal, cx| {
225                let terminal = terminal.entity().read(cx);
226                let latest_output = terminal.last_n_non_empty_lines(DEFAULT_CONTEXT_LINES);
227                let working_directory = terminal
228                    .working_directory()
229                    .map(|path| path.to_string_lossy().to_string());
230                (latest_output, working_directory)
231            })
232            .ok()
233            .unwrap_or_default();
234
235        let prompt = self.prompt_builder.generate_terminal_assistant_prompt(
236            &assist
237                .prompt_editor
238                .clone()
239                .context("invalid assist")?
240                .read(cx)
241                .prompt(cx),
242            shell.as_deref(),
243            working_directory.as_deref(),
244            &latest_output,
245        )?;
246
247        let mut request_message = LanguageModelRequestMessage {
248            role: Role::User,
249            content: vec![],
250            cache: false,
251        };
252
253        attach_context_to_message(
254            &mut request_message,
255            assist.context_store.read(cx).snapshot(cx),
256        );
257
258        request_message.content.push(prompt.into());
259
260        Ok(LanguageModelRequest {
261            messages: vec![request_message],
262            tools: Vec::new(),
263            stop: Vec::new(),
264            temperature: None,
265        })
266    }
267
268    fn finish_assist(
269        &mut self,
270        assist_id: TerminalInlineAssistId,
271        undo: bool,
272        execute: bool,
273        window: &mut Window,
274        cx: &mut App,
275    ) {
276        self.dismiss_assist(assist_id, window, cx);
277
278        if let Some(assist) = self.assists.remove(&assist_id) {
279            assist
280                .terminal
281                .update(cx, |this, cx| {
282                    this.clear_block_below_cursor(cx);
283                    this.focus_handle(cx).focus(window);
284                })
285                .log_err();
286
287            if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
288                let codegen = assist.codegen.read(cx);
289                let executor = cx.background_executor().clone();
290                report_assistant_event(
291                    AssistantEvent {
292                        conversation_id: None,
293                        kind: AssistantKind::InlineTerminal,
294                        message_id: codegen.message_id.clone(),
295                        phase: if undo {
296                            AssistantPhase::Rejected
297                        } else {
298                            AssistantPhase::Accepted
299                        },
300                        model: model.telemetry_id(),
301                        model_provider: model.provider_id().to_string(),
302                        response_latency: None,
303                        error_message: None,
304                        language_name: None,
305                    },
306                    codegen.telemetry.clone(),
307                    cx.http_client(),
308                    model.api_key(cx),
309                    &executor,
310                );
311            }
312
313            assist.codegen.update(cx, |codegen, cx| {
314                if undo {
315                    codegen.undo(cx);
316                } else if execute {
317                    codegen.complete(cx);
318                }
319            });
320        }
321    }
322
323    fn dismiss_assist(
324        &mut self,
325        assist_id: TerminalInlineAssistId,
326        window: &mut Window,
327        cx: &mut App,
328    ) -> bool {
329        let Some(assist) = self.assists.get_mut(&assist_id) else {
330            return false;
331        };
332        if assist.prompt_editor.is_none() {
333            return false;
334        }
335        assist.prompt_editor = None;
336        assist
337            .terminal
338            .update(cx, |this, cx| {
339                this.clear_block_below_cursor(cx);
340                this.focus_handle(cx).focus(window);
341            })
342            .is_ok()
343    }
344
345    fn insert_prompt_editor_into_terminal(
346        &mut self,
347        assist_id: TerminalInlineAssistId,
348        height: u8,
349        window: &mut Window,
350        cx: &mut App,
351    ) {
352        if let Some(assist) = self.assists.get_mut(&assist_id) {
353            if let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned() {
354                assist
355                    .terminal
356                    .update(cx, |terminal, cx| {
357                        terminal.clear_block_below_cursor(cx);
358                        let block = terminal_view::BlockProperties {
359                            height,
360                            render: Box::new(move |_| prompt_editor.clone().into_any_element()),
361                        };
362                        terminal.set_block_below_cursor(block, window, cx);
363                    })
364                    .log_err();
365            }
366        }
367    }
368}
369
/// State for one in-flight terminal inline assist.
struct TerminalInlineAssist {
    /// Terminal view hosting the assist; held weakly, so updates through it can fail.
    terminal: WeakEntity<TerminalView>,
    /// Prompt editor rendered below the terminal cursor; `None` once the assist has been
    /// dismissed.
    prompt_editor: Option<Entity<PromptEditor<TerminalCodegen>>>,
    /// Codegen entity obtained from the prompt editor.
    codegen: Entity<TerminalCodegen>,
    /// Workspace used to show an error toast when generation fails after dismissal.
    workspace: WeakEntity<Workspace>,
    /// Source of the context attached to the request message.
    context_store: Entity<ContextStore>,
    /// Subscriptions to prompt-editor and codegen events, stored alongside the assist.
    _subscriptions: Vec<Subscription>,
}
378
379impl TerminalInlineAssist {
380    pub fn new(
381        assist_id: TerminalInlineAssistId,
382        terminal: &Entity<TerminalView>,
383        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
384        workspace: WeakEntity<Workspace>,
385        context_store: Entity<ContextStore>,
386        window: &mut Window,
387        cx: &mut App,
388    ) -> Self {
389        let codegen = prompt_editor.read(cx).codegen().clone();
390        Self {
391            terminal: terminal.downgrade(),
392            prompt_editor: Some(prompt_editor.clone()),
393            codegen: codegen.clone(),
394            workspace: workspace.clone(),
395            context_store,
396            _subscriptions: vec![
397                window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
398                    TerminalInlineAssistant::update_global(cx, |this, cx| {
399                        this.handle_prompt_editor_event(prompt_editor, event, window, cx)
400                    })
401                }),
402                window.subscribe(&codegen, cx, move |codegen, event, window, cx| {
403                    TerminalInlineAssistant::update_global(cx, |this, cx| match event {
404                        CodegenEvent::Finished => {
405                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
406                                assist
407                            } else {
408                                return;
409                            };
410
411                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
412                                if assist.prompt_editor.is_none() {
413                                    if let Some(workspace) = assist.workspace.upgrade() {
414                                        let error =
415                                            format!("Terminal inline assistant error: {}", error);
416                                        workspace.update(cx, |workspace, cx| {
417                                            struct InlineAssistantError;
418
419                                            let id =
420                                                NotificationId::composite::<InlineAssistantError>(
421                                                    assist_id.0,
422                                                );
423
424                                            workspace.show_toast(Toast::new(id, error), cx);
425                                        })
426                                    }
427                                }
428                            }
429
430                            if assist.prompt_editor.is_none() {
431                                this.finish_assist(assist_id, false, false, window, cx);
432                            }
433                        }
434                    })
435                }),
436            ],
437        }
438    }
439}