terminal_inline_assistant.rs

  1use crate::context::load_context;
  2use crate::context_store::ContextStore;
  3use crate::inline_prompt_editor::{
  4    CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
  5};
  6use crate::terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen};
  7use crate::thread_store::{TextThreadStore, ThreadStore};
  8use anyhow::{Context as _, Result};
  9use client::telemetry::Telemetry;
 10use collections::{HashMap, VecDeque};
 11use editor::{MultiBuffer, actions::SelectAll};
 12use fs::Fs;
 13use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
 14use language::Buffer;
 15use language_model::{
 16    ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
 17    Role, report_assistant_event,
 18};
 19use project::Project;
 20use prompt_store::{PromptBuilder, PromptStore};
 21use std::sync::Arc;
 22use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
 23use terminal_view::TerminalView;
 24use ui::prelude::*;
 25use util::ResultExt;
 26use workspace::{Toast, Workspace, notifications::NotificationId};
 27
 28pub fn init(
 29    fs: Arc<dyn Fs>,
 30    prompt_builder: Arc<PromptBuilder>,
 31    telemetry: Arc<Telemetry>,
 32    cx: &mut App,
 33) {
 34    cx.set_global(TerminalInlineAssistant::new(fs, prompt_builder, telemetry));
 35}
 36
 37const DEFAULT_CONTEXT_LINES: usize = 50;
 38const PROMPT_HISTORY_MAX_LEN: usize = 20;
 39
 40pub struct TerminalInlineAssistant {
 41    next_assist_id: TerminalInlineAssistId,
 42    assists: HashMap<TerminalInlineAssistId, TerminalInlineAssist>,
 43    prompt_history: VecDeque<String>,
 44    telemetry: Option<Arc<Telemetry>>,
 45    fs: Arc<dyn Fs>,
 46    prompt_builder: Arc<PromptBuilder>,
 47}
 48
 49impl Global for TerminalInlineAssistant {}
 50
 51impl TerminalInlineAssistant {
 52    pub fn new(
 53        fs: Arc<dyn Fs>,
 54        prompt_builder: Arc<PromptBuilder>,
 55        telemetry: Arc<Telemetry>,
 56    ) -> Self {
 57        Self {
 58            next_assist_id: TerminalInlineAssistId::default(),
 59            assists: HashMap::default(),
 60            prompt_history: VecDeque::default(),
 61            telemetry: Some(telemetry),
 62            fs,
 63            prompt_builder,
 64        }
 65    }
 66
 67    pub fn assist(
 68        &mut self,
 69        terminal_view: &Entity<TerminalView>,
 70        workspace: WeakEntity<Workspace>,
 71        project: WeakEntity<Project>,
 72        prompt_store: Option<Entity<PromptStore>>,
 73        thread_store: Option<WeakEntity<ThreadStore>>,
 74        text_thread_store: Option<WeakEntity<TextThreadStore>>,
 75        window: &mut Window,
 76        cx: &mut App,
 77    ) {
 78        let terminal = terminal_view.read(cx).terminal().clone();
 79        let assist_id = self.next_assist_id.post_inc();
 80        let prompt_buffer =
 81            cx.new(|cx| MultiBuffer::singleton(cx.new(|cx| Buffer::local(String::new(), cx)), cx));
 82        let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
 83        let codegen = cx.new(|_| TerminalCodegen::new(terminal, self.telemetry.clone()));
 84
 85        let prompt_editor = cx.new(|cx| {
 86            PromptEditor::new_terminal(
 87                assist_id,
 88                self.prompt_history.clone(),
 89                prompt_buffer.clone(),
 90                codegen,
 91                self.fs.clone(),
 92                context_store.clone(),
 93                workspace.clone(),
 94                thread_store.clone(),
 95                text_thread_store.clone(),
 96                window,
 97                cx,
 98            )
 99        });
100        let prompt_editor_render = prompt_editor.clone();
101        let block = terminal_view::BlockProperties {
102            height: 2,
103            render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
104        };
105        terminal_view.update(cx, |terminal_view, cx| {
106            terminal_view.set_block_below_cursor(block, window, cx);
107        });
108
109        let terminal_assistant = TerminalInlineAssist::new(
110            assist_id,
111            terminal_view,
112            prompt_editor,
113            workspace.clone(),
114            context_store,
115            prompt_store,
116            window,
117            cx,
118        );
119
120        self.assists.insert(assist_id, terminal_assistant);
121
122        self.focus_assist(assist_id, window, cx);
123    }
124
125    fn focus_assist(
126        &mut self,
127        assist_id: TerminalInlineAssistId,
128        window: &mut Window,
129        cx: &mut App,
130    ) {
131        let assist = &self.assists[&assist_id];
132        if let Some(prompt_editor) = assist.prompt_editor.as_ref() {
133            prompt_editor.update(cx, |this, cx| {
134                this.editor.update(cx, |editor, cx| {
135                    window.focus(&editor.focus_handle(cx));
136                    editor.select_all(&SelectAll, window, cx);
137                });
138            });
139        }
140    }
141
142    fn handle_prompt_editor_event(
143        &mut self,
144        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
145        event: &PromptEditorEvent,
146        window: &mut Window,
147        cx: &mut App,
148    ) {
149        let assist_id = prompt_editor.read(cx).id();
150        match event {
151            PromptEditorEvent::StartRequested => {
152                self.start_assist(assist_id, cx);
153            }
154            PromptEditorEvent::StopRequested => {
155                self.stop_assist(assist_id, cx);
156            }
157            PromptEditorEvent::ConfirmRequested { execute } => {
158                self.finish_assist(assist_id, false, *execute, window, cx);
159            }
160            PromptEditorEvent::CancelRequested => {
161                self.finish_assist(assist_id, true, false, window, cx);
162            }
163            PromptEditorEvent::DismissRequested => {
164                self.dismiss_assist(assist_id, window, cx);
165            }
166            PromptEditorEvent::Resized { height_in_lines } => {
167                self.insert_prompt_editor_into_terminal(assist_id, *height_in_lines, window, cx);
168            }
169        }
170    }
171
172    fn start_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
173        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
174            assist
175        } else {
176            return;
177        };
178
179        let Some(user_prompt) = assist
180            .prompt_editor
181            .as_ref()
182            .map(|editor| editor.read(cx).prompt(cx))
183        else {
184            return;
185        };
186
187        self.prompt_history.retain(|prompt| *prompt != user_prompt);
188        self.prompt_history.push_back(user_prompt.clone());
189        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
190            self.prompt_history.pop_front();
191        }
192
193        assist
194            .terminal
195            .update(cx, |terminal, cx| {
196                terminal
197                    .terminal()
198                    .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
199            })
200            .log_err();
201
202        let codegen = assist.codegen.clone();
203        let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
204            return;
205        };
206
207        codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
208    }
209
210    fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
211        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
212            assist
213        } else {
214            return;
215        };
216
217        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
218    }
219
220    fn request_for_inline_assist(
221        &self,
222        assist_id: TerminalInlineAssistId,
223        cx: &mut App,
224    ) -> Result<Task<LanguageModelRequest>> {
225        let assist = self.assists.get(&assist_id).context("invalid assist")?;
226
227        let shell = std::env::var("SHELL").ok();
228        let (latest_output, working_directory) = assist
229            .terminal
230            .update(cx, |terminal, cx| {
231                let terminal = terminal.entity().read(cx);
232                let latest_output = terminal.last_n_non_empty_lines(DEFAULT_CONTEXT_LINES);
233                let working_directory = terminal
234                    .working_directory()
235                    .map(|path| path.to_string_lossy().to_string());
236                (latest_output, working_directory)
237            })
238            .ok()
239            .unwrap_or_default();
240
241        let prompt = self.prompt_builder.generate_terminal_assistant_prompt(
242            &assist
243                .prompt_editor
244                .clone()
245                .context("invalid assist")?
246                .read(cx)
247                .prompt(cx),
248            shell.as_deref(),
249            working_directory.as_deref(),
250            &latest_output,
251        )?;
252
253        let contexts = assist
254            .context_store
255            .read(cx)
256            .context()
257            .cloned()
258            .collect::<Vec<_>>();
259        let context_load_task = assist.workspace.update(cx, |workspace, cx| {
260            let project = workspace.project();
261            load_context(contexts, project, &assist.prompt_store, cx)
262        })?;
263
264        Ok(cx.background_spawn(async move {
265            let mut request_message = LanguageModelRequestMessage {
266                role: Role::User,
267                content: vec![],
268                cache: false,
269            };
270
271            context_load_task
272                .await
273                .loaded_context
274                .add_to_request_message(&mut request_message);
275
276            request_message.content.push(prompt.into());
277
278            LanguageModelRequest {
279                thread_id: None,
280                prompt_id: None,
281                mode: None,
282                messages: vec![request_message],
283                tools: Vec::new(),
284                stop: Vec::new(),
285                temperature: None,
286            }
287        }))
288    }
289
290    fn finish_assist(
291        &mut self,
292        assist_id: TerminalInlineAssistId,
293        undo: bool,
294        execute: bool,
295        window: &mut Window,
296        cx: &mut App,
297    ) {
298        self.dismiss_assist(assist_id, window, cx);
299
300        if let Some(assist) = self.assists.remove(&assist_id) {
301            assist
302                .terminal
303                .update(cx, |this, cx| {
304                    this.clear_block_below_cursor(cx);
305                    this.focus_handle(cx).focus(window);
306                })
307                .log_err();
308
309            if let Some(ConfiguredModel { model, .. }) =
310                LanguageModelRegistry::read_global(cx).inline_assistant_model()
311            {
312                let codegen = assist.codegen.read(cx);
313                let executor = cx.background_executor().clone();
314                report_assistant_event(
315                    AssistantEventData {
316                        conversation_id: None,
317                        kind: AssistantKind::InlineTerminal,
318                        message_id: codegen.message_id.clone(),
319                        phase: if undo {
320                            AssistantPhase::Rejected
321                        } else {
322                            AssistantPhase::Accepted
323                        },
324                        model: model.telemetry_id(),
325                        model_provider: model.provider_id().to_string(),
326                        response_latency: None,
327                        error_message: None,
328                        language_name: None,
329                    },
330                    codegen.telemetry.clone(),
331                    cx.http_client(),
332                    model.api_key(cx),
333                    &executor,
334                );
335            }
336
337            assist.codegen.update(cx, |codegen, cx| {
338                if undo {
339                    codegen.undo(cx);
340                } else if execute {
341                    codegen.complete(cx);
342                }
343            });
344        }
345    }
346
347    fn dismiss_assist(
348        &mut self,
349        assist_id: TerminalInlineAssistId,
350        window: &mut Window,
351        cx: &mut App,
352    ) -> bool {
353        let Some(assist) = self.assists.get_mut(&assist_id) else {
354            return false;
355        };
356        if assist.prompt_editor.is_none() {
357            return false;
358        }
359        assist.prompt_editor = None;
360        assist
361            .terminal
362            .update(cx, |this, cx| {
363                this.clear_block_below_cursor(cx);
364                this.focus_handle(cx).focus(window);
365            })
366            .is_ok()
367    }
368
369    fn insert_prompt_editor_into_terminal(
370        &mut self,
371        assist_id: TerminalInlineAssistId,
372        height: u8,
373        window: &mut Window,
374        cx: &mut App,
375    ) {
376        if let Some(assist) = self.assists.get_mut(&assist_id) {
377            if let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned() {
378                assist
379                    .terminal
380                    .update(cx, |terminal, cx| {
381                        terminal.clear_block_below_cursor(cx);
382                        let block = terminal_view::BlockProperties {
383                            height,
384                            render: Box::new(move |_| prompt_editor.clone().into_any_element()),
385                        };
386                        terminal.set_block_below_cursor(block, window, cx);
387                    })
388                    .log_err();
389            }
390        }
391    }
392}
393
394struct TerminalInlineAssist {
395    terminal: WeakEntity<TerminalView>,
396    prompt_editor: Option<Entity<PromptEditor<TerminalCodegen>>>,
397    codegen: Entity<TerminalCodegen>,
398    workspace: WeakEntity<Workspace>,
399    context_store: Entity<ContextStore>,
400    prompt_store: Option<Entity<PromptStore>>,
401    _subscriptions: Vec<Subscription>,
402}
403
404impl TerminalInlineAssist {
405    pub fn new(
406        assist_id: TerminalInlineAssistId,
407        terminal: &Entity<TerminalView>,
408        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
409        workspace: WeakEntity<Workspace>,
410        context_store: Entity<ContextStore>,
411        prompt_store: Option<Entity<PromptStore>>,
412        window: &mut Window,
413        cx: &mut App,
414    ) -> Self {
415        let codegen = prompt_editor.read(cx).codegen().clone();
416        Self {
417            terminal: terminal.downgrade(),
418            prompt_editor: Some(prompt_editor.clone()),
419            codegen: codegen.clone(),
420            workspace: workspace.clone(),
421            context_store,
422            prompt_store,
423            _subscriptions: vec![
424                window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
425                    TerminalInlineAssistant::update_global(cx, |this, cx| {
426                        this.handle_prompt_editor_event(prompt_editor, event, window, cx)
427                    })
428                }),
429                window.subscribe(&codegen, cx, move |codegen, event, window, cx| {
430                    TerminalInlineAssistant::update_global(cx, |this, cx| match event {
431                        CodegenEvent::Finished => {
432                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
433                                assist
434                            } else {
435                                return;
436                            };
437
438                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
439                                if assist.prompt_editor.is_none() {
440                                    if let Some(workspace) = assist.workspace.upgrade() {
441                                        let error =
442                                            format!("Terminal inline assistant error: {}", error);
443                                        workspace.update(cx, |workspace, cx| {
444                                            struct InlineAssistantError;
445
446                                            let id =
447                                                NotificationId::composite::<InlineAssistantError>(
448                                                    assist_id.0,
449                                                );
450
451                                            workspace.show_toast(Toast::new(id, error), cx);
452                                        })
453                                    }
454                                }
455                            }
456
457                            if assist.prompt_editor.is_none() {
458                                this.finish_assist(assist_id, false, false, window, cx);
459                            }
460                        }
461                    })
462                }),
463            ],
464        }
465    }
466}