terminal_inline_assistant.rs

  1use crate::context::attach_context_to_message;
  2use crate::context_store::ContextStore;
  3use crate::inline_prompt_editor::{
  4    CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
  5};
  6use crate::prompts::PromptBuilder;
  7use crate::terminal_codegen::{CodegenEvent, TerminalCodegen, CLEAR_INPUT};
  8use crate::thread_store::ThreadStore;
  9use anyhow::{Context as _, Result};
 10use client::telemetry::Telemetry;
 11use collections::{HashMap, VecDeque};
 12use editor::{actions::SelectAll, MultiBuffer};
 13use fs::Fs;
 14use gpui::{
 15    AppContext, Context, FocusableView, Global, Model, Subscription, UpdateGlobal, View, WeakModel,
 16    WeakView,
 17};
 18use language::Buffer;
 19use language_model::{
 20    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
 21};
 22use language_models::report_assistant_event;
 23use std::sync::Arc;
 24use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
 25use terminal_view::TerminalView;
 26use ui::prelude::*;
 27use util::ResultExt;
 28use workspace::{notifications::NotificationId, Toast, Workspace};
 29
 30pub fn init(
 31    fs: Arc<dyn Fs>,
 32    prompt_builder: Arc<PromptBuilder>,
 33    telemetry: Arc<Telemetry>,
 34    cx: &mut AppContext,
 35) {
 36    cx.set_global(TerminalInlineAssistant::new(fs, prompt_builder, telemetry));
 37}
 38
 39const DEFAULT_CONTEXT_LINES: usize = 50;
 40const PROMPT_HISTORY_MAX_LEN: usize = 20;
 41
 42pub struct TerminalInlineAssistant {
 43    next_assist_id: TerminalInlineAssistId,
 44    assists: HashMap<TerminalInlineAssistId, TerminalInlineAssist>,
 45    prompt_history: VecDeque<String>,
 46    telemetry: Option<Arc<Telemetry>>,
 47    fs: Arc<dyn Fs>,
 48    prompt_builder: Arc<PromptBuilder>,
 49}
 50
 51impl Global for TerminalInlineAssistant {}
 52
 53impl TerminalInlineAssistant {
 54    pub fn new(
 55        fs: Arc<dyn Fs>,
 56        prompt_builder: Arc<PromptBuilder>,
 57        telemetry: Arc<Telemetry>,
 58    ) -> Self {
 59        Self {
 60            next_assist_id: TerminalInlineAssistId::default(),
 61            assists: HashMap::default(),
 62            prompt_history: VecDeque::default(),
 63            telemetry: Some(telemetry),
 64            fs,
 65            prompt_builder,
 66        }
 67    }
 68
 69    pub fn assist(
 70        &mut self,
 71        terminal_view: &View<TerminalView>,
 72        workspace: WeakView<Workspace>,
 73        thread_store: Option<WeakModel<ThreadStore>>,
 74        cx: &mut WindowContext,
 75    ) {
 76        let terminal = terminal_view.read(cx).terminal().clone();
 77        let assist_id = self.next_assist_id.post_inc();
 78        let prompt_buffer = cx.new_model(|cx| {
 79            MultiBuffer::singleton(cx.new_model(|cx| Buffer::local(String::new(), cx)), cx)
 80        });
 81        let context_store = cx.new_model(|_cx| ContextStore::new(workspace.clone()));
 82        let codegen = cx.new_model(|_| TerminalCodegen::new(terminal, self.telemetry.clone()));
 83
 84        let prompt_editor = cx.new_view(|cx| {
 85            PromptEditor::new_terminal(
 86                assist_id,
 87                self.prompt_history.clone(),
 88                prompt_buffer.clone(),
 89                codegen,
 90                self.fs.clone(),
 91                context_store.clone(),
 92                workspace.clone(),
 93                thread_store.clone(),
 94                cx,
 95            )
 96        });
 97        let prompt_editor_render = prompt_editor.clone();
 98        let block = terminal_view::BlockProperties {
 99            height: 2,
100            render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
101        };
102        terminal_view.update(cx, |terminal_view, cx| {
103            terminal_view.set_block_below_cursor(block, cx);
104        });
105
106        let terminal_assistant = TerminalInlineAssist::new(
107            assist_id,
108            terminal_view,
109            prompt_editor,
110            workspace.clone(),
111            context_store,
112            cx,
113        );
114
115        self.assists.insert(assist_id, terminal_assistant);
116
117        self.focus_assist(assist_id, cx);
118    }
119
120    fn focus_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut WindowContext) {
121        let assist = &self.assists[&assist_id];
122        if let Some(prompt_editor) = assist.prompt_editor.as_ref() {
123            prompt_editor.update(cx, |this, cx| {
124                this.editor.update(cx, |editor, cx| {
125                    editor.focus(cx);
126                    editor.select_all(&SelectAll, cx);
127                });
128            });
129        }
130    }
131
132    fn handle_prompt_editor_event(
133        &mut self,
134        prompt_editor: View<PromptEditor<TerminalCodegen>>,
135        event: &PromptEditorEvent,
136        cx: &mut WindowContext,
137    ) {
138        let assist_id = prompt_editor.read(cx).id();
139        match event {
140            PromptEditorEvent::StartRequested => {
141                self.start_assist(assist_id, cx);
142            }
143            PromptEditorEvent::StopRequested => {
144                self.stop_assist(assist_id, cx);
145            }
146            PromptEditorEvent::ConfirmRequested { execute } => {
147                self.finish_assist(assist_id, false, *execute, cx);
148            }
149            PromptEditorEvent::CancelRequested => {
150                self.finish_assist(assist_id, true, false, cx);
151            }
152            PromptEditorEvent::DismissRequested => {
153                self.dismiss_assist(assist_id, cx);
154            }
155            PromptEditorEvent::Resized { height_in_lines } => {
156                self.insert_prompt_editor_into_terminal(assist_id, *height_in_lines, cx);
157            }
158        }
159    }
160
161    fn start_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut WindowContext) {
162        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
163            assist
164        } else {
165            return;
166        };
167
168        let Some(user_prompt) = assist
169            .prompt_editor
170            .as_ref()
171            .map(|editor| editor.read(cx).prompt(cx))
172        else {
173            return;
174        };
175
176        self.prompt_history.retain(|prompt| *prompt != user_prompt);
177        self.prompt_history.push_back(user_prompt.clone());
178        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
179            self.prompt_history.pop_front();
180        }
181
182        assist
183            .terminal
184            .update(cx, |terminal, cx| {
185                terminal
186                    .terminal()
187                    .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
188            })
189            .log_err();
190
191        let codegen = assist.codegen.clone();
192        let Some(request) = self.request_for_inline_assist(assist_id, cx).log_err() else {
193            return;
194        };
195
196        codegen.update(cx, |codegen, cx| codegen.start(request, cx));
197    }
198
199    fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut WindowContext) {
200        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
201            assist
202        } else {
203            return;
204        };
205
206        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
207    }
208
209    fn request_for_inline_assist(
210        &self,
211        assist_id: TerminalInlineAssistId,
212        cx: &mut WindowContext,
213    ) -> Result<LanguageModelRequest> {
214        let assist = self.assists.get(&assist_id).context("invalid assist")?;
215
216        let shell = std::env::var("SHELL").ok();
217        let (latest_output, working_directory) = assist
218            .terminal
219            .update(cx, |terminal, cx| {
220                let terminal = terminal.model().read(cx);
221                let latest_output = terminal.last_n_non_empty_lines(DEFAULT_CONTEXT_LINES);
222                let working_directory = terminal
223                    .working_directory()
224                    .map(|path| path.to_string_lossy().to_string());
225                (latest_output, working_directory)
226            })
227            .ok()
228            .unwrap_or_default();
229
230        let prompt = self.prompt_builder.generate_terminal_assistant_prompt(
231            &assist
232                .prompt_editor
233                .clone()
234                .context("invalid assist")?
235                .read(cx)
236                .prompt(cx),
237            shell.as_deref(),
238            working_directory.as_deref(),
239            &latest_output,
240        )?;
241
242        let mut request_message = LanguageModelRequestMessage {
243            role: Role::User,
244            content: vec![],
245            cache: false,
246        };
247
248        attach_context_to_message(
249            &mut request_message,
250            assist.context_store.read(cx).snapshot(cx),
251        );
252
253        request_message.content.push(prompt.into());
254
255        Ok(LanguageModelRequest {
256            messages: vec![request_message],
257            tools: Vec::new(),
258            stop: Vec::new(),
259            temperature: None,
260        })
261    }
262
263    fn finish_assist(
264        &mut self,
265        assist_id: TerminalInlineAssistId,
266        undo: bool,
267        execute: bool,
268        cx: &mut WindowContext,
269    ) {
270        self.dismiss_assist(assist_id, cx);
271
272        if let Some(assist) = self.assists.remove(&assist_id) {
273            assist
274                .terminal
275                .update(cx, |this, cx| {
276                    this.clear_block_below_cursor(cx);
277                    this.focus_handle(cx).focus(cx);
278                })
279                .log_err();
280
281            if let Some(model) = LanguageModelRegistry::read_global(cx).active_model() {
282                let codegen = assist.codegen.read(cx);
283                let executor = cx.background_executor().clone();
284                report_assistant_event(
285                    AssistantEvent {
286                        conversation_id: None,
287                        kind: AssistantKind::InlineTerminal,
288                        message_id: codegen.message_id.clone(),
289                        phase: if undo {
290                            AssistantPhase::Rejected
291                        } else {
292                            AssistantPhase::Accepted
293                        },
294                        model: model.telemetry_id(),
295                        model_provider: model.provider_id().to_string(),
296                        response_latency: None,
297                        error_message: None,
298                        language_name: None,
299                    },
300                    codegen.telemetry.clone(),
301                    cx.http_client(),
302                    model.api_key(cx),
303                    &executor,
304                );
305            }
306
307            assist.codegen.update(cx, |codegen, cx| {
308                if undo {
309                    codegen.undo(cx);
310                } else if execute {
311                    codegen.complete(cx);
312                }
313            });
314        }
315    }
316
317    fn dismiss_assist(
318        &mut self,
319        assist_id: TerminalInlineAssistId,
320        cx: &mut WindowContext,
321    ) -> bool {
322        let Some(assist) = self.assists.get_mut(&assist_id) else {
323            return false;
324        };
325        if assist.prompt_editor.is_none() {
326            return false;
327        }
328        assist.prompt_editor = None;
329        assist
330            .terminal
331            .update(cx, |this, cx| {
332                this.clear_block_below_cursor(cx);
333                this.focus_handle(cx).focus(cx);
334            })
335            .is_ok()
336    }
337
338    fn insert_prompt_editor_into_terminal(
339        &mut self,
340        assist_id: TerminalInlineAssistId,
341        height: u8,
342        cx: &mut WindowContext,
343    ) {
344        if let Some(assist) = self.assists.get_mut(&assist_id) {
345            if let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned() {
346                assist
347                    .terminal
348                    .update(cx, |terminal, cx| {
349                        terminal.clear_block_below_cursor(cx);
350                        let block = terminal_view::BlockProperties {
351                            height,
352                            render: Box::new(move |_| prompt_editor.clone().into_any_element()),
353                        };
354                        terminal.set_block_below_cursor(block, cx);
355                    })
356                    .log_err();
357            }
358        }
359    }
360}
361
362struct TerminalInlineAssist {
363    terminal: WeakView<TerminalView>,
364    prompt_editor: Option<View<PromptEditor<TerminalCodegen>>>,
365    codegen: Model<TerminalCodegen>,
366    workspace: WeakView<Workspace>,
367    context_store: Model<ContextStore>,
368    _subscriptions: Vec<Subscription>,
369}
370
371impl TerminalInlineAssist {
372    pub fn new(
373        assist_id: TerminalInlineAssistId,
374        terminal: &View<TerminalView>,
375        prompt_editor: View<PromptEditor<TerminalCodegen>>,
376        workspace: WeakView<Workspace>,
377        context_store: Model<ContextStore>,
378        cx: &mut WindowContext,
379    ) -> Self {
380        let codegen = prompt_editor.read(cx).codegen().clone();
381        Self {
382            terminal: terminal.downgrade(),
383            prompt_editor: Some(prompt_editor.clone()),
384            codegen: codegen.clone(),
385            workspace: workspace.clone(),
386            context_store,
387            _subscriptions: vec![
388                cx.subscribe(&prompt_editor, |prompt_editor, event, cx| {
389                    TerminalInlineAssistant::update_global(cx, |this, cx| {
390                        this.handle_prompt_editor_event(prompt_editor, event, cx)
391                    })
392                }),
393                cx.subscribe(&codegen, move |codegen, event, cx| {
394                    TerminalInlineAssistant::update_global(cx, |this, cx| match event {
395                        CodegenEvent::Finished => {
396                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
397                                assist
398                            } else {
399                                return;
400                            };
401
402                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
403                                if assist.prompt_editor.is_none() {
404                                    if let Some(workspace) = assist.workspace.upgrade() {
405                                        let error =
406                                            format!("Terminal inline assistant error: {}", error);
407                                        workspace.update(cx, |workspace, cx| {
408                                            struct InlineAssistantError;
409
410                                            let id =
411                                                NotificationId::composite::<InlineAssistantError>(
412                                                    assist_id.0,
413                                                );
414
415                                            workspace.show_toast(Toast::new(id, error), cx);
416                                        })
417                                    }
418                                }
419                            }
420
421                            if assist.prompt_editor.is_none() {
422                                this.finish_assist(assist_id, false, false, cx);
423                            }
424                        }
425                    })
426                }),
427            ],
428        }
429    }
430}