terminal_inline_assistant.rs

  1use crate::context::load_context;
  2use crate::context_store::ContextStore;
  3use crate::inline_prompt_editor::{
  4    CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
  5};
  6use crate::terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen};
  7use crate::thread_store::ThreadStore;
  8use anyhow::{Context as _, Result};
  9use client::telemetry::Telemetry;
 10use collections::{HashMap, VecDeque};
 11use editor::{MultiBuffer, actions::SelectAll};
 12use fs::Fs;
 13use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
 14use language::Buffer;
 15use language_model::{
 16    ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
 17    Role, report_assistant_event,
 18};
 19use project::Project;
 20use prompt_store::{PromptBuilder, PromptStore};
 21use std::sync::Arc;
 22use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
 23use terminal_view::TerminalView;
 24use ui::prelude::*;
 25use util::ResultExt;
 26use workspace::{Toast, Workspace, notifications::NotificationId};
 27
 28pub fn init(
 29    fs: Arc<dyn Fs>,
 30    prompt_builder: Arc<PromptBuilder>,
 31    telemetry: Arc<Telemetry>,
 32    cx: &mut App,
 33) {
 34    cx.set_global(TerminalInlineAssistant::new(fs, prompt_builder, telemetry));
 35}
 36
 37const DEFAULT_CONTEXT_LINES: usize = 50;
 38const PROMPT_HISTORY_MAX_LEN: usize = 20;
 39
 40pub struct TerminalInlineAssistant {
 41    next_assist_id: TerminalInlineAssistId,
 42    assists: HashMap<TerminalInlineAssistId, TerminalInlineAssist>,
 43    prompt_history: VecDeque<String>,
 44    telemetry: Option<Arc<Telemetry>>,
 45    fs: Arc<dyn Fs>,
 46    prompt_builder: Arc<PromptBuilder>,
 47}
 48
 49impl Global for TerminalInlineAssistant {}
 50
 51impl TerminalInlineAssistant {
 52    pub fn new(
 53        fs: Arc<dyn Fs>,
 54        prompt_builder: Arc<PromptBuilder>,
 55        telemetry: Arc<Telemetry>,
 56    ) -> Self {
 57        Self {
 58            next_assist_id: TerminalInlineAssistId::default(),
 59            assists: HashMap::default(),
 60            prompt_history: VecDeque::default(),
 61            telemetry: Some(telemetry),
 62            fs,
 63            prompt_builder,
 64        }
 65    }
 66
 67    pub fn assist(
 68        &mut self,
 69        terminal_view: &Entity<TerminalView>,
 70        workspace: WeakEntity<Workspace>,
 71        project: WeakEntity<Project>,
 72        prompt_store: Option<Entity<PromptStore>>,
 73        thread_store: Option<WeakEntity<ThreadStore>>,
 74        window: &mut Window,
 75        cx: &mut App,
 76    ) {
 77        let terminal = terminal_view.read(cx).terminal().clone();
 78        let assist_id = self.next_assist_id.post_inc();
 79        let prompt_buffer =
 80            cx.new(|cx| MultiBuffer::singleton(cx.new(|cx| Buffer::local(String::new(), cx)), cx));
 81        let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
 82        let codegen = cx.new(|_| TerminalCodegen::new(terminal, self.telemetry.clone()));
 83
 84        let prompt_editor = cx.new(|cx| {
 85            PromptEditor::new_terminal(
 86                assist_id,
 87                self.prompt_history.clone(),
 88                prompt_buffer.clone(),
 89                codegen,
 90                self.fs.clone(),
 91                context_store.clone(),
 92                workspace.clone(),
 93                thread_store.clone(),
 94                window,
 95                cx,
 96            )
 97        });
 98        let prompt_editor_render = prompt_editor.clone();
 99        let block = terminal_view::BlockProperties {
100            height: 2,
101            render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
102        };
103        terminal_view.update(cx, |terminal_view, cx| {
104            terminal_view.set_block_below_cursor(block, window, cx);
105        });
106
107        let terminal_assistant = TerminalInlineAssist::new(
108            assist_id,
109            terminal_view,
110            prompt_editor,
111            workspace.clone(),
112            context_store,
113            prompt_store,
114            window,
115            cx,
116        );
117
118        self.assists.insert(assist_id, terminal_assistant);
119
120        self.focus_assist(assist_id, window, cx);
121    }
122
123    fn focus_assist(
124        &mut self,
125        assist_id: TerminalInlineAssistId,
126        window: &mut Window,
127        cx: &mut App,
128    ) {
129        let assist = &self.assists[&assist_id];
130        if let Some(prompt_editor) = assist.prompt_editor.as_ref() {
131            prompt_editor.update(cx, |this, cx| {
132                this.editor.update(cx, |editor, cx| {
133                    window.focus(&editor.focus_handle(cx));
134                    editor.select_all(&SelectAll, window, cx);
135                });
136            });
137        }
138    }
139
140    fn handle_prompt_editor_event(
141        &mut self,
142        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
143        event: &PromptEditorEvent,
144        window: &mut Window,
145        cx: &mut App,
146    ) {
147        let assist_id = prompt_editor.read(cx).id();
148        match event {
149            PromptEditorEvent::StartRequested => {
150                self.start_assist(assist_id, cx);
151            }
152            PromptEditorEvent::StopRequested => {
153                self.stop_assist(assist_id, cx);
154            }
155            PromptEditorEvent::ConfirmRequested { execute } => {
156                self.finish_assist(assist_id, false, *execute, window, cx);
157            }
158            PromptEditorEvent::CancelRequested => {
159                self.finish_assist(assist_id, true, false, window, cx);
160            }
161            PromptEditorEvent::DismissRequested => {
162                self.dismiss_assist(assist_id, window, cx);
163            }
164            PromptEditorEvent::Resized { height_in_lines } => {
165                self.insert_prompt_editor_into_terminal(assist_id, *height_in_lines, window, cx);
166            }
167        }
168    }
169
170    fn start_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
171        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
172            assist
173        } else {
174            return;
175        };
176
177        let Some(user_prompt) = assist
178            .prompt_editor
179            .as_ref()
180            .map(|editor| editor.read(cx).prompt(cx))
181        else {
182            return;
183        };
184
185        self.prompt_history.retain(|prompt| *prompt != user_prompt);
186        self.prompt_history.push_back(user_prompt.clone());
187        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
188            self.prompt_history.pop_front();
189        }
190
191        assist
192            .terminal
193            .update(cx, |terminal, cx| {
194                terminal
195                    .terminal()
196                    .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
197            })
198            .log_err();
199
200        let codegen = assist.codegen.clone();
201        let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
202            return;
203        };
204
205        codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
206    }
207
208    fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
209        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
210            assist
211        } else {
212            return;
213        };
214
215        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
216    }
217
218    fn request_for_inline_assist(
219        &self,
220        assist_id: TerminalInlineAssistId,
221        cx: &mut App,
222    ) -> Result<Task<LanguageModelRequest>> {
223        let assist = self.assists.get(&assist_id).context("invalid assist")?;
224
225        let shell = std::env::var("SHELL").ok();
226        let (latest_output, working_directory) = assist
227            .terminal
228            .update(cx, |terminal, cx| {
229                let terminal = terminal.entity().read(cx);
230                let latest_output = terminal.last_n_non_empty_lines(DEFAULT_CONTEXT_LINES);
231                let working_directory = terminal
232                    .working_directory()
233                    .map(|path| path.to_string_lossy().to_string());
234                (latest_output, working_directory)
235            })
236            .ok()
237            .unwrap_or_default();
238
239        let prompt = self.prompt_builder.generate_terminal_assistant_prompt(
240            &assist
241                .prompt_editor
242                .clone()
243                .context("invalid assist")?
244                .read(cx)
245                .prompt(cx),
246            shell.as_deref(),
247            working_directory.as_deref(),
248            &latest_output,
249        )?;
250
251        let contexts = assist
252            .context_store
253            .read(cx)
254            .context()
255            .cloned()
256            .collect::<Vec<_>>();
257        let context_load_task = assist.workspace.update(cx, |workspace, cx| {
258            let project = workspace.project();
259            load_context(contexts, project, &assist.prompt_store, cx)
260        })?;
261
262        Ok(cx.background_spawn(async move {
263            let mut request_message = LanguageModelRequestMessage {
264                role: Role::User,
265                content: vec![],
266                cache: false,
267            };
268
269            context_load_task
270                .await
271                .loaded_context
272                .add_to_request_message(&mut request_message);
273
274            request_message.content.push(prompt.into());
275
276            LanguageModelRequest {
277                thread_id: None,
278                prompt_id: None,
279                mode: None,
280                messages: vec![request_message],
281                tools: Vec::new(),
282                stop: Vec::new(),
283                temperature: None,
284            }
285        }))
286    }
287
288    fn finish_assist(
289        &mut self,
290        assist_id: TerminalInlineAssistId,
291        undo: bool,
292        execute: bool,
293        window: &mut Window,
294        cx: &mut App,
295    ) {
296        self.dismiss_assist(assist_id, window, cx);
297
298        if let Some(assist) = self.assists.remove(&assist_id) {
299            assist
300                .terminal
301                .update(cx, |this, cx| {
302                    this.clear_block_below_cursor(cx);
303                    this.focus_handle(cx).focus(window);
304                })
305                .log_err();
306
307            if let Some(ConfiguredModel { model, .. }) =
308                LanguageModelRegistry::read_global(cx).inline_assistant_model()
309            {
310                let codegen = assist.codegen.read(cx);
311                let executor = cx.background_executor().clone();
312                report_assistant_event(
313                    AssistantEventData {
314                        conversation_id: None,
315                        kind: AssistantKind::InlineTerminal,
316                        message_id: codegen.message_id.clone(),
317                        phase: if undo {
318                            AssistantPhase::Rejected
319                        } else {
320                            AssistantPhase::Accepted
321                        },
322                        model: model.telemetry_id(),
323                        model_provider: model.provider_id().to_string(),
324                        response_latency: None,
325                        error_message: None,
326                        language_name: None,
327                    },
328                    codegen.telemetry.clone(),
329                    cx.http_client(),
330                    model.api_key(cx),
331                    &executor,
332                );
333            }
334
335            assist.codegen.update(cx, |codegen, cx| {
336                if undo {
337                    codegen.undo(cx);
338                } else if execute {
339                    codegen.complete(cx);
340                }
341            });
342        }
343    }
344
345    fn dismiss_assist(
346        &mut self,
347        assist_id: TerminalInlineAssistId,
348        window: &mut Window,
349        cx: &mut App,
350    ) -> bool {
351        let Some(assist) = self.assists.get_mut(&assist_id) else {
352            return false;
353        };
354        if assist.prompt_editor.is_none() {
355            return false;
356        }
357        assist.prompt_editor = None;
358        assist
359            .terminal
360            .update(cx, |this, cx| {
361                this.clear_block_below_cursor(cx);
362                this.focus_handle(cx).focus(window);
363            })
364            .is_ok()
365    }
366
367    fn insert_prompt_editor_into_terminal(
368        &mut self,
369        assist_id: TerminalInlineAssistId,
370        height: u8,
371        window: &mut Window,
372        cx: &mut App,
373    ) {
374        if let Some(assist) = self.assists.get_mut(&assist_id) {
375            if let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned() {
376                assist
377                    .terminal
378                    .update(cx, |terminal, cx| {
379                        terminal.clear_block_below_cursor(cx);
380                        let block = terminal_view::BlockProperties {
381                            height,
382                            render: Box::new(move |_| prompt_editor.clone().into_any_element()),
383                        };
384                        terminal.set_block_below_cursor(block, window, cx);
385                    })
386                    .log_err();
387            }
388        }
389    }
390}
391
392struct TerminalInlineAssist {
393    terminal: WeakEntity<TerminalView>,
394    prompt_editor: Option<Entity<PromptEditor<TerminalCodegen>>>,
395    codegen: Entity<TerminalCodegen>,
396    workspace: WeakEntity<Workspace>,
397    context_store: Entity<ContextStore>,
398    prompt_store: Option<Entity<PromptStore>>,
399    _subscriptions: Vec<Subscription>,
400}
401
402impl TerminalInlineAssist {
403    pub fn new(
404        assist_id: TerminalInlineAssistId,
405        terminal: &Entity<TerminalView>,
406        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
407        workspace: WeakEntity<Workspace>,
408        context_store: Entity<ContextStore>,
409        prompt_store: Option<Entity<PromptStore>>,
410        window: &mut Window,
411        cx: &mut App,
412    ) -> Self {
413        let codegen = prompt_editor.read(cx).codegen().clone();
414        Self {
415            terminal: terminal.downgrade(),
416            prompt_editor: Some(prompt_editor.clone()),
417            codegen: codegen.clone(),
418            workspace: workspace.clone(),
419            context_store,
420            prompt_store,
421            _subscriptions: vec![
422                window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
423                    TerminalInlineAssistant::update_global(cx, |this, cx| {
424                        this.handle_prompt_editor_event(prompt_editor, event, window, cx)
425                    })
426                }),
427                window.subscribe(&codegen, cx, move |codegen, event, window, cx| {
428                    TerminalInlineAssistant::update_global(cx, |this, cx| match event {
429                        CodegenEvent::Finished => {
430                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
431                                assist
432                            } else {
433                                return;
434                            };
435
436                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
437                                if assist.prompt_editor.is_none() {
438                                    if let Some(workspace) = assist.workspace.upgrade() {
439                                        let error =
440                                            format!("Terminal inline assistant error: {}", error);
441                                        workspace.update(cx, |workspace, cx| {
442                                            struct InlineAssistantError;
443
444                                            let id =
445                                                NotificationId::composite::<InlineAssistantError>(
446                                                    assist_id.0,
447                                                );
448
449                                            workspace.show_toast(Toast::new(id, error), cx);
450                                        })
451                                    }
452                                }
453                            }
454
455                            if assist.prompt_editor.is_none() {
456                                this.finish_assist(assist_id, false, false, window, cx);
457                            }
458                        }
459                    })
460                }),
461            ],
462        }
463    }
464}