terminal_inline_assistant.rs

  1use crate::context::load_context;
  2use crate::context_store::ContextStore;
  3use crate::inline_prompt_editor::{
  4    CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
  5};
  6use crate::terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen};
  7use crate::thread_store::{TextThreadStore, ThreadStore};
  8use anyhow::{Context as _, Result};
  9use assistant_settings::AssistantSettings;
 10use client::telemetry::Telemetry;
 11use collections::{HashMap, VecDeque};
 12use editor::{MultiBuffer, actions::SelectAll};
 13use fs::Fs;
 14use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
 15use language::Buffer;
 16use language_model::{
 17    ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
 18    Role, report_assistant_event,
 19};
 20use project::Project;
 21use prompt_store::{PromptBuilder, PromptStore};
 22use std::sync::Arc;
 23use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
 24use terminal_view::TerminalView;
 25use ui::prelude::*;
 26use util::ResultExt;
 27use workspace::{Toast, Workspace, notifications::NotificationId};
 28
 29pub fn init(
 30    fs: Arc<dyn Fs>,
 31    prompt_builder: Arc<PromptBuilder>,
 32    telemetry: Arc<Telemetry>,
 33    cx: &mut App,
 34) {
 35    cx.set_global(TerminalInlineAssistant::new(fs, prompt_builder, telemetry));
 36}
 37
/// How many of the terminal's most recent non-empty lines are passed to
/// `last_n_non_empty_lines` and included in the generated assistant prompt.
const DEFAULT_CONTEXT_LINES: usize = 50;
/// Upper bound on the shared prompt history; once exceeded, the oldest prompt is dropped.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
 40
/// Global coordinator for inline assists opened inside terminal views.
///
/// Owns every live assist, keyed by id, plus a bounded history of submitted prompts that is
/// handed to each newly created prompt editor.
pub struct TerminalInlineAssistant {
    // Id given to the next assist created by `assist`; advanced with `post_inc`.
    next_assist_id: TerminalInlineAssistId,
    // Live assists; an entry is removed in `finish_assist`.
    assists: HashMap<TerminalInlineAssistId, TerminalInlineAssist>,
    // Previously submitted prompts, oldest first, de-duplicated and capped at
    // `PROMPT_HISTORY_MAX_LEN` by `start_assist`.
    prompt_history: VecDeque<String>,
    // Telemetry client cloned into each `TerminalCodegen`.
    telemetry: Option<Arc<Telemetry>>,
    // Filesystem handle cloned into each prompt editor.
    fs: Arc<dyn Fs>,
    // Renders the terminal-assistant prompt text sent to the model.
    prompt_builder: Arc<PromptBuilder>,
}

// Installed with `cx.set_global` in `init` and accessed through `update_global`.
impl Global for TerminalInlineAssistant {}
 51
 52impl TerminalInlineAssistant {
 53    pub fn new(
 54        fs: Arc<dyn Fs>,
 55        prompt_builder: Arc<PromptBuilder>,
 56        telemetry: Arc<Telemetry>,
 57    ) -> Self {
 58        Self {
 59            next_assist_id: TerminalInlineAssistId::default(),
 60            assists: HashMap::default(),
 61            prompt_history: VecDeque::default(),
 62            telemetry: Some(telemetry),
 63            fs,
 64            prompt_builder,
 65        }
 66    }
 67
 68    pub fn assist(
 69        &mut self,
 70        terminal_view: &Entity<TerminalView>,
 71        workspace: WeakEntity<Workspace>,
 72        project: WeakEntity<Project>,
 73        prompt_store: Option<Entity<PromptStore>>,
 74        thread_store: Option<WeakEntity<ThreadStore>>,
 75        text_thread_store: Option<WeakEntity<TextThreadStore>>,
 76        initial_prompt: Option<String>,
 77        window: &mut Window,
 78        cx: &mut App,
 79    ) {
 80        let terminal = terminal_view.read(cx).terminal().clone();
 81        let assist_id = self.next_assist_id.post_inc();
 82        let prompt_buffer = cx.new(|cx| {
 83            MultiBuffer::singleton(
 84                cx.new(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx)),
 85                cx,
 86            )
 87        });
 88        let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
 89        let codegen = cx.new(|_| TerminalCodegen::new(terminal, self.telemetry.clone()));
 90
 91        let prompt_editor = cx.new(|cx| {
 92            PromptEditor::new_terminal(
 93                assist_id,
 94                self.prompt_history.clone(),
 95                prompt_buffer.clone(),
 96                codegen,
 97                self.fs.clone(),
 98                context_store.clone(),
 99                workspace.clone(),
100                thread_store.clone(),
101                text_thread_store.clone(),
102                window,
103                cx,
104            )
105        });
106        let prompt_editor_render = prompt_editor.clone();
107        let block = terminal_view::BlockProperties {
108            height: 2,
109            render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
110        };
111        terminal_view.update(cx, |terminal_view, cx| {
112            terminal_view.set_block_below_cursor(block, window, cx);
113        });
114
115        let terminal_assistant = TerminalInlineAssist::new(
116            assist_id,
117            terminal_view,
118            prompt_editor,
119            workspace.clone(),
120            context_store,
121            prompt_store,
122            window,
123            cx,
124        );
125
126        self.assists.insert(assist_id, terminal_assistant);
127
128        self.focus_assist(assist_id, window, cx);
129    }
130
131    fn focus_assist(
132        &mut self,
133        assist_id: TerminalInlineAssistId,
134        window: &mut Window,
135        cx: &mut App,
136    ) {
137        let assist = &self.assists[&assist_id];
138        if let Some(prompt_editor) = assist.prompt_editor.as_ref() {
139            prompt_editor.update(cx, |this, cx| {
140                this.editor.update(cx, |editor, cx| {
141                    window.focus(&editor.focus_handle(cx));
142                    editor.select_all(&SelectAll, window, cx);
143                });
144            });
145        }
146    }
147
148    fn handle_prompt_editor_event(
149        &mut self,
150        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
151        event: &PromptEditorEvent,
152        window: &mut Window,
153        cx: &mut App,
154    ) {
155        let assist_id = prompt_editor.read(cx).id();
156        match event {
157            PromptEditorEvent::StartRequested => {
158                self.start_assist(assist_id, cx);
159            }
160            PromptEditorEvent::StopRequested => {
161                self.stop_assist(assist_id, cx);
162            }
163            PromptEditorEvent::ConfirmRequested { execute } => {
164                self.finish_assist(assist_id, false, *execute, window, cx);
165            }
166            PromptEditorEvent::CancelRequested => {
167                self.finish_assist(assist_id, true, false, window, cx);
168            }
169            PromptEditorEvent::DismissRequested => {
170                self.dismiss_assist(assist_id, window, cx);
171            }
172            PromptEditorEvent::Resized { height_in_lines } => {
173                self.insert_prompt_editor_into_terminal(assist_id, *height_in_lines, window, cx);
174            }
175        }
176    }
177
178    fn start_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
179        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
180            assist
181        } else {
182            return;
183        };
184
185        let Some(user_prompt) = assist
186            .prompt_editor
187            .as_ref()
188            .map(|editor| editor.read(cx).prompt(cx))
189        else {
190            return;
191        };
192
193        self.prompt_history.retain(|prompt| *prompt != user_prompt);
194        self.prompt_history.push_back(user_prompt.clone());
195        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
196            self.prompt_history.pop_front();
197        }
198
199        assist
200            .terminal
201            .update(cx, |terminal, cx| {
202                terminal
203                    .terminal()
204                    .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
205            })
206            .log_err();
207
208        let codegen = assist.codegen.clone();
209        let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
210            return;
211        };
212
213        codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
214    }
215
216    fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
217        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
218            assist
219        } else {
220            return;
221        };
222
223        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
224    }
225
226    fn request_for_inline_assist(
227        &self,
228        assist_id: TerminalInlineAssistId,
229        cx: &mut App,
230    ) -> Result<Task<LanguageModelRequest>> {
231        let assist = self.assists.get(&assist_id).context("invalid assist")?;
232
233        let shell = std::env::var("SHELL").ok();
234        let (latest_output, working_directory) = assist
235            .terminal
236            .update(cx, |terminal, cx| {
237                let terminal = terminal.entity().read(cx);
238                let latest_output = terminal.last_n_non_empty_lines(DEFAULT_CONTEXT_LINES);
239                let working_directory = terminal
240                    .working_directory()
241                    .map(|path| path.to_string_lossy().to_string());
242                (latest_output, working_directory)
243            })
244            .ok()
245            .unwrap_or_default();
246
247        let prompt = self.prompt_builder.generate_terminal_assistant_prompt(
248            &assist
249                .prompt_editor
250                .clone()
251                .context("invalid assist")?
252                .read(cx)
253                .prompt(cx),
254            shell.as_deref(),
255            working_directory.as_deref(),
256            &latest_output,
257        )?;
258
259        let contexts = assist
260            .context_store
261            .read(cx)
262            .context()
263            .cloned()
264            .collect::<Vec<_>>();
265        let context_load_task = assist.workspace.update(cx, |workspace, cx| {
266            let project = workspace.project();
267            load_context(contexts, project, &assist.prompt_store, cx)
268        })?;
269
270        let ConfiguredModel { model, .. } = LanguageModelRegistry::read_global(cx)
271            .inline_assistant_model()
272            .context("No inline assistant model")?;
273
274        let temperature = AssistantSettings::temperature_for_model(&model, cx);
275
276        Ok(cx.background_spawn(async move {
277            let mut request_message = LanguageModelRequestMessage {
278                role: Role::User,
279                content: vec![],
280                cache: false,
281            };
282
283            context_load_task
284                .await
285                .loaded_context
286                .add_to_request_message(&mut request_message);
287
288            request_message.content.push(prompt.into());
289
290            LanguageModelRequest {
291                thread_id: None,
292                prompt_id: None,
293                mode: None,
294                messages: vec![request_message],
295                tools: Vec::new(),
296                stop: Vec::new(),
297                temperature,
298            }
299        }))
300    }
301
302    fn finish_assist(
303        &mut self,
304        assist_id: TerminalInlineAssistId,
305        undo: bool,
306        execute: bool,
307        window: &mut Window,
308        cx: &mut App,
309    ) {
310        self.dismiss_assist(assist_id, window, cx);
311
312        if let Some(assist) = self.assists.remove(&assist_id) {
313            assist
314                .terminal
315                .update(cx, |this, cx| {
316                    this.clear_block_below_cursor(cx);
317                    this.focus_handle(cx).focus(window);
318                })
319                .log_err();
320
321            if let Some(ConfiguredModel { model, .. }) =
322                LanguageModelRegistry::read_global(cx).inline_assistant_model()
323            {
324                let codegen = assist.codegen.read(cx);
325                let executor = cx.background_executor().clone();
326                report_assistant_event(
327                    AssistantEventData {
328                        conversation_id: None,
329                        kind: AssistantKind::InlineTerminal,
330                        message_id: codegen.message_id.clone(),
331                        phase: if undo {
332                            AssistantPhase::Rejected
333                        } else {
334                            AssistantPhase::Accepted
335                        },
336                        model: model.telemetry_id(),
337                        model_provider: model.provider_id().to_string(),
338                        response_latency: None,
339                        error_message: None,
340                        language_name: None,
341                    },
342                    codegen.telemetry.clone(),
343                    cx.http_client(),
344                    model.api_key(cx),
345                    &executor,
346                );
347            }
348
349            assist.codegen.update(cx, |codegen, cx| {
350                if undo {
351                    codegen.undo(cx);
352                } else if execute {
353                    codegen.complete(cx);
354                }
355            });
356        }
357    }
358
359    fn dismiss_assist(
360        &mut self,
361        assist_id: TerminalInlineAssistId,
362        window: &mut Window,
363        cx: &mut App,
364    ) -> bool {
365        let Some(assist) = self.assists.get_mut(&assist_id) else {
366            return false;
367        };
368        if assist.prompt_editor.is_none() {
369            return false;
370        }
371        assist.prompt_editor = None;
372        assist
373            .terminal
374            .update(cx, |this, cx| {
375                this.clear_block_below_cursor(cx);
376                this.focus_handle(cx).focus(window);
377            })
378            .is_ok()
379    }
380
381    fn insert_prompt_editor_into_terminal(
382        &mut self,
383        assist_id: TerminalInlineAssistId,
384        height: u8,
385        window: &mut Window,
386        cx: &mut App,
387    ) {
388        if let Some(assist) = self.assists.get_mut(&assist_id) {
389            if let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned() {
390                assist
391                    .terminal
392                    .update(cx, |terminal, cx| {
393                        terminal.clear_block_below_cursor(cx);
394                        let block = terminal_view::BlockProperties {
395                            height,
396                            render: Box::new(move |_| prompt_editor.clone().into_any_element()),
397                        };
398                        terminal.set_block_below_cursor(block, window, cx);
399                    })
400                    .log_err();
401            }
402        }
403    }
404}
405
/// State of a single inline assist attached to a terminal view.
struct TerminalInlineAssist {
    // Terminal view hosting the prompt block; held weakly, so updates may fail once it is gone.
    terminal: WeakEntity<TerminalView>,
    // Prompt editor rendered below the cursor; `None` once the assist has been dismissed.
    prompt_editor: Option<Entity<PromptEditor<TerminalCodegen>>>,
    // Codegen that runs the model request; the same entity the prompt editor exposes.
    codegen: Entity<TerminalCodegen>,
    // Workspace used to load context and to show error toasts.
    workspace: WeakEntity<Workspace>,
    // Context entries loaded into the request message via `load_context`.
    context_store: Entity<ContextStore>,
    // Prompt store forwarded to `load_context`.
    prompt_store: Option<Entity<PromptStore>>,
    // Held so the event subscriptions created in `new` stay active while this assist exists.
    _subscriptions: Vec<Subscription>,
}
415
416impl TerminalInlineAssist {
417    pub fn new(
418        assist_id: TerminalInlineAssistId,
419        terminal: &Entity<TerminalView>,
420        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
421        workspace: WeakEntity<Workspace>,
422        context_store: Entity<ContextStore>,
423        prompt_store: Option<Entity<PromptStore>>,
424        window: &mut Window,
425        cx: &mut App,
426    ) -> Self {
427        let codegen = prompt_editor.read(cx).codegen().clone();
428        Self {
429            terminal: terminal.downgrade(),
430            prompt_editor: Some(prompt_editor.clone()),
431            codegen: codegen.clone(),
432            workspace: workspace.clone(),
433            context_store,
434            prompt_store,
435            _subscriptions: vec![
436                window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
437                    TerminalInlineAssistant::update_global(cx, |this, cx| {
438                        this.handle_prompt_editor_event(prompt_editor, event, window, cx)
439                    })
440                }),
441                window.subscribe(&codegen, cx, move |codegen, event, window, cx| {
442                    TerminalInlineAssistant::update_global(cx, |this, cx| match event {
443                        CodegenEvent::Finished => {
444                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
445                                assist
446                            } else {
447                                return;
448                            };
449
450                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
451                                if assist.prompt_editor.is_none() {
452                                    if let Some(workspace) = assist.workspace.upgrade() {
453                                        let error =
454                                            format!("Terminal inline assistant error: {}", error);
455                                        workspace.update(cx, |workspace, cx| {
456                                            struct InlineAssistantError;
457
458                                            let id =
459                                                NotificationId::composite::<InlineAssistantError>(
460                                                    assist_id.0,
461                                                );
462
463                                            workspace.show_toast(Toast::new(id, error), cx);
464                                        })
465                                    }
466                                }
467                            }
468
469                            if assist.prompt_editor.is_none() {
470                                this.finish_assist(assist_id, false, false, window, cx);
471                            }
472                        }
473                    })
474                }),
475            ],
476        }
477    }
478}