terminal_inline_assistant.rs

  1use crate::context::attach_context_to_message;
  2use crate::context_store::ContextStore;
  3use crate::inline_prompt_editor::{
  4    CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
  5};
  6use crate::terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen};
  7use crate::thread_store::ThreadStore;
  8use anyhow::{Context as _, Result};
  9use client::telemetry::Telemetry;
 10use collections::{HashMap, VecDeque};
 11use editor::{MultiBuffer, actions::SelectAll};
 12use fs::Fs;
 13use gpui::{App, Entity, Focusable, Global, Subscription, UpdateGlobal, WeakEntity};
 14use language::Buffer;
 15use language_model::{
 16    ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
 17    Role, report_assistant_event,
 18};
 19use project::Project;
 20use prompt_store::PromptBuilder;
 21use std::sync::Arc;
 22use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
 23use terminal_view::TerminalView;
 24use ui::prelude::*;
 25use util::ResultExt;
 26use workspace::{Toast, Workspace, notifications::NotificationId};
 27
 28pub fn init(
 29    fs: Arc<dyn Fs>,
 30    prompt_builder: Arc<PromptBuilder>,
 31    telemetry: Arc<Telemetry>,
 32    cx: &mut App,
 33) {
 34    cx.set_global(TerminalInlineAssistant::new(fs, prompt_builder, telemetry));
 35}
 36
 37const DEFAULT_CONTEXT_LINES: usize = 50;
 38const PROMPT_HISTORY_MAX_LEN: usize = 20;
 39
/// Global coordinator for inline assists opened inside terminal views.
///
/// Owns every live assist by id, the list of recently submitted prompts, and the shared
/// services (file system, prompt builder, telemetry) handed to each new assist.
pub struct TerminalInlineAssistant {
    /// Id for the next assist; advanced with `post_inc` each time `assist` is called.
    next_assist_id: TerminalInlineAssistId,
    /// Live assists keyed by id; an entry is removed by `finish_assist`.
    assists: HashMap<TerminalInlineAssistId, TerminalInlineAssist>,
    /// Recently submitted prompts, oldest at the front, capped at `PROMPT_HISTORY_MAX_LEN`.
    prompt_history: VecDeque<String>,
    /// Telemetry client cloned into each `TerminalCodegen`; always `Some` when built by `new`.
    telemetry: Option<Arc<Telemetry>>,
    /// File system handle passed to each prompt editor.
    fs: Arc<dyn Fs>,
    /// Renders the terminal-assistant prompt template for each request.
    prompt_builder: Arc<PromptBuilder>,
}

// Stored as an application global by `init` and accessed via `update_global`.
impl Global for TerminalInlineAssistant {}
 50
 51impl TerminalInlineAssistant {
 52    pub fn new(
 53        fs: Arc<dyn Fs>,
 54        prompt_builder: Arc<PromptBuilder>,
 55        telemetry: Arc<Telemetry>,
 56    ) -> Self {
 57        Self {
 58            next_assist_id: TerminalInlineAssistId::default(),
 59            assists: HashMap::default(),
 60            prompt_history: VecDeque::default(),
 61            telemetry: Some(telemetry),
 62            fs,
 63            prompt_builder,
 64        }
 65    }
 66
 67    pub fn assist(
 68        &mut self,
 69        terminal_view: &Entity<TerminalView>,
 70        workspace: WeakEntity<Workspace>,
 71        project: WeakEntity<Project>,
 72        thread_store: Option<WeakEntity<ThreadStore>>,
 73        window: &mut Window,
 74        cx: &mut App,
 75    ) {
 76        let terminal = terminal_view.read(cx).terminal().clone();
 77        let assist_id = self.next_assist_id.post_inc();
 78        let prompt_buffer =
 79            cx.new(|cx| MultiBuffer::singleton(cx.new(|cx| Buffer::local(String::new(), cx)), cx));
 80        let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
 81        let codegen = cx.new(|_| TerminalCodegen::new(terminal, self.telemetry.clone()));
 82
 83        let prompt_editor = cx.new(|cx| {
 84            PromptEditor::new_terminal(
 85                assist_id,
 86                self.prompt_history.clone(),
 87                prompt_buffer.clone(),
 88                codegen,
 89                self.fs.clone(),
 90                context_store.clone(),
 91                workspace.clone(),
 92                thread_store.clone(),
 93                window,
 94                cx,
 95            )
 96        });
 97        let prompt_editor_render = prompt_editor.clone();
 98        let block = terminal_view::BlockProperties {
 99            height: 2,
100            render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
101        };
102        terminal_view.update(cx, |terminal_view, cx| {
103            terminal_view.set_block_below_cursor(block, window, cx);
104        });
105
106        let terminal_assistant = TerminalInlineAssist::new(
107            assist_id,
108            terminal_view,
109            prompt_editor,
110            workspace.clone(),
111            context_store,
112            window,
113            cx,
114        );
115
116        self.assists.insert(assist_id, terminal_assistant);
117
118        self.focus_assist(assist_id, window, cx);
119    }
120
121    fn focus_assist(
122        &mut self,
123        assist_id: TerminalInlineAssistId,
124        window: &mut Window,
125        cx: &mut App,
126    ) {
127        let assist = &self.assists[&assist_id];
128        if let Some(prompt_editor) = assist.prompt_editor.as_ref() {
129            prompt_editor.update(cx, |this, cx| {
130                this.editor.update(cx, |editor, cx| {
131                    window.focus(&editor.focus_handle(cx));
132                    editor.select_all(&SelectAll, window, cx);
133                });
134            });
135        }
136    }
137
138    fn handle_prompt_editor_event(
139        &mut self,
140        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
141        event: &PromptEditorEvent,
142        window: &mut Window,
143        cx: &mut App,
144    ) {
145        let assist_id = prompt_editor.read(cx).id();
146        match event {
147            PromptEditorEvent::StartRequested => {
148                self.start_assist(assist_id, cx);
149            }
150            PromptEditorEvent::StopRequested => {
151                self.stop_assist(assist_id, cx);
152            }
153            PromptEditorEvent::ConfirmRequested { execute } => {
154                self.finish_assist(assist_id, false, *execute, window, cx);
155            }
156            PromptEditorEvent::CancelRequested => {
157                self.finish_assist(assist_id, true, false, window, cx);
158            }
159            PromptEditorEvent::DismissRequested => {
160                self.dismiss_assist(assist_id, window, cx);
161            }
162            PromptEditorEvent::Resized { height_in_lines } => {
163                self.insert_prompt_editor_into_terminal(assist_id, *height_in_lines, window, cx);
164            }
165        }
166    }
167
168    fn start_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
169        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
170            assist
171        } else {
172            return;
173        };
174
175        let Some(user_prompt) = assist
176            .prompt_editor
177            .as_ref()
178            .map(|editor| editor.read(cx).prompt(cx))
179        else {
180            return;
181        };
182
183        self.prompt_history.retain(|prompt| *prompt != user_prompt);
184        self.prompt_history.push_back(user_prompt.clone());
185        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
186            self.prompt_history.pop_front();
187        }
188
189        assist
190            .terminal
191            .update(cx, |terminal, cx| {
192                terminal
193                    .terminal()
194                    .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
195            })
196            .log_err();
197
198        let codegen = assist.codegen.clone();
199        let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else {
200            return;
201        };
202
203        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
204    }
205
206    fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
207        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
208            assist
209        } else {
210            return;
211        };
212
213        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
214    }
215
216    fn request_for_inline_assist(
217        &self,
218        assist_id: TerminalInlineAssistId,
219        cx: &mut App,
220    ) -> Result<LanguageModelRequest> {
221        let assist = self.assists.get(&assist_id).context("invalid assist")?;
222
223        let shell = std::env::var("SHELL").ok();
224        let (latest_output, working_directory) = assist
225            .terminal
226            .update(cx, |terminal, cx| {
227                let terminal = terminal.entity().read(cx);
228                let latest_output = terminal.last_n_non_empty_lines(DEFAULT_CONTEXT_LINES);
229                let working_directory = terminal
230                    .working_directory()
231                    .map(|path| path.to_string_lossy().to_string());
232                (latest_output, working_directory)
233            })
234            .ok()
235            .unwrap_or_default();
236
237        let prompt = self.prompt_builder.generate_terminal_assistant_prompt(
238            &assist
239                .prompt_editor
240                .clone()
241                .context("invalid assist")?
242                .read(cx)
243                .prompt(cx),
244            shell.as_deref(),
245            working_directory.as_deref(),
246            &latest_output,
247        )?;
248
249        let mut request_message = LanguageModelRequestMessage {
250            role: Role::User,
251            content: vec![],
252            cache: false,
253        };
254
255        attach_context_to_message(
256            &mut request_message,
257            assist.context_store.read(cx).context().iter(),
258            cx,
259        );
260
261        request_message.content.push(prompt.into());
262
263        Ok(LanguageModelRequest {
264            messages: vec![request_message],
265            tools: Vec::new(),
266            stop: Vec::new(),
267            temperature: None,
268        })
269    }
270
271    fn finish_assist(
272        &mut self,
273        assist_id: TerminalInlineAssistId,
274        undo: bool,
275        execute: bool,
276        window: &mut Window,
277        cx: &mut App,
278    ) {
279        self.dismiss_assist(assist_id, window, cx);
280
281        if let Some(assist) = self.assists.remove(&assist_id) {
282            assist
283                .terminal
284                .update(cx, |this, cx| {
285                    this.clear_block_below_cursor(cx);
286                    this.focus_handle(cx).focus(window);
287                })
288                .log_err();
289
290            if let Some(ConfiguredModel { model, .. }) =
291                LanguageModelRegistry::read_global(cx).inline_assistant_model()
292            {
293                let codegen = assist.codegen.read(cx);
294                let executor = cx.background_executor().clone();
295                report_assistant_event(
296                    AssistantEventData {
297                        conversation_id: None,
298                        kind: AssistantKind::InlineTerminal,
299                        message_id: codegen.message_id.clone(),
300                        phase: if undo {
301                            AssistantPhase::Rejected
302                        } else {
303                            AssistantPhase::Accepted
304                        },
305                        model: model.telemetry_id(),
306                        model_provider: model.provider_id().to_string(),
307                        response_latency: None,
308                        error_message: None,
309                        language_name: None,
310                    },
311                    codegen.telemetry.clone(),
312                    cx.http_client(),
313                    model.api_key(cx),
314                    &executor,
315                );
316            }
317
318            assist.codegen.update(cx, |codegen, cx| {
319                if undo {
320                    codegen.undo(cx);
321                } else if execute {
322                    codegen.complete(cx);
323                }
324            });
325        }
326    }
327
328    fn dismiss_assist(
329        &mut self,
330        assist_id: TerminalInlineAssistId,
331        window: &mut Window,
332        cx: &mut App,
333    ) -> bool {
334        let Some(assist) = self.assists.get_mut(&assist_id) else {
335            return false;
336        };
337        if assist.prompt_editor.is_none() {
338            return false;
339        }
340        assist.prompt_editor = None;
341        assist
342            .terminal
343            .update(cx, |this, cx| {
344                this.clear_block_below_cursor(cx);
345                this.focus_handle(cx).focus(window);
346            })
347            .is_ok()
348    }
349
350    fn insert_prompt_editor_into_terminal(
351        &mut self,
352        assist_id: TerminalInlineAssistId,
353        height: u8,
354        window: &mut Window,
355        cx: &mut App,
356    ) {
357        if let Some(assist) = self.assists.get_mut(&assist_id) {
358            if let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned() {
359                assist
360                    .terminal
361                    .update(cx, |terminal, cx| {
362                        terminal.clear_block_below_cursor(cx);
363                        let block = terminal_view::BlockProperties {
364                            height,
365                            render: Box::new(move |_| prompt_editor.clone().into_any_element()),
366                        };
367                        terminal.set_block_below_cursor(block, window, cx);
368                    })
369                    .log_err();
370            }
371        }
372    }
373}
374
/// State for a single terminal inline assist, owned by `TerminalInlineAssistant::assists`.
struct TerminalInlineAssist {
    /// Terminal view hosting the assist; weak handle, and callers tolerate failed updates.
    terminal: WeakEntity<TerminalView>,
    /// Prompt editor rendered below the cursor; `None` once the assist has been dismissed.
    prompt_editor: Option<Entity<PromptEditor<TerminalCodegen>>>,
    /// Code generator; the same entity the prompt editor exposes via `codegen()`.
    codegen: Entity<TerminalCodegen>,
    /// Workspace used to show an error toast when generation fails after dismissal.
    workspace: WeakEntity<Workspace>,
    /// Context attached to the request built for this assist.
    context_store: Entity<ContextStore>,
    /// Window subscriptions created in `new`; held for their lifetime, never read.
    _subscriptions: Vec<Subscription>,
}
383
384impl TerminalInlineAssist {
385    pub fn new(
386        assist_id: TerminalInlineAssistId,
387        terminal: &Entity<TerminalView>,
388        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
389        workspace: WeakEntity<Workspace>,
390        context_store: Entity<ContextStore>,
391        window: &mut Window,
392        cx: &mut App,
393    ) -> Self {
394        let codegen = prompt_editor.read(cx).codegen().clone();
395        Self {
396            terminal: terminal.downgrade(),
397            prompt_editor: Some(prompt_editor.clone()),
398            codegen: codegen.clone(),
399            workspace: workspace.clone(),
400            context_store,
401            _subscriptions: vec![
402                window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
403                    TerminalInlineAssistant::update_global(cx, |this, cx| {
404                        this.handle_prompt_editor_event(prompt_editor, event, window, cx)
405                    })
406                }),
407                window.subscribe(&codegen, cx, move |codegen, event, window, cx| {
408                    TerminalInlineAssistant::update_global(cx, |this, cx| match event {
409                        CodegenEvent::Finished => {
410                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
411                                assist
412                            } else {
413                                return;
414                            };
415
416                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
417                                if assist.prompt_editor.is_none() {
418                                    if let Some(workspace) = assist.workspace.upgrade() {
419                                        let error =
420                                            format!("Terminal inline assistant error: {}", error);
421                                        workspace.update(cx, |workspace, cx| {
422                                            struct InlineAssistantError;
423
424                                            let id =
425                                                NotificationId::composite::<InlineAssistantError>(
426                                                    assist_id.0,
427                                                );
428
429                                            workspace.show_toast(Toast::new(id, error), cx);
430                                        })
431                                    }
432                                }
433                            }
434
435                            if assist.prompt_editor.is_none() {
436                                this.finish_assist(assist_id, false, false, window, cx);
437                            }
438                        }
439                    })
440                }),
441            ],
442        }
443    }
444}