terminal_inline_assistant.rs

  1use crate::context::load_context;
  2use crate::context_store::ContextStore;
  3use crate::inline_prompt_editor::{
  4    CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
  5};
  6use crate::terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen};
  7use crate::thread_store::{TextThreadStore, ThreadStore};
  8use agent_settings::AgentSettings;
  9use anyhow::{Context as _, Result};
 10use client::telemetry::Telemetry;
 11use collections::{HashMap, VecDeque};
 12use editor::{MultiBuffer, actions::SelectAll};
 13use fs::Fs;
 14use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
 15use language::Buffer;
 16use language_model::{
 17    ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
 18    Role, report_assistant_event,
 19};
 20use project::Project;
 21use prompt_store::{PromptBuilder, PromptStore};
 22use std::sync::Arc;
 23use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
 24use terminal_view::TerminalView;
 25use ui::prelude::*;
 26use util::ResultExt;
 27use workspace::{Toast, Workspace, notifications::NotificationId};
 28use zed_llm_client::CompletionIntent;
 29
 30pub fn init(
 31    fs: Arc<dyn Fs>,
 32    prompt_builder: Arc<PromptBuilder>,
 33    telemetry: Arc<Telemetry>,
 34    cx: &mut App,
 35) {
 36    cx.set_global(TerminalInlineAssistant::new(fs, prompt_builder, telemetry));
 37}
 38
 39const DEFAULT_CONTEXT_LINES: usize = 50;
 40const PROMPT_HISTORY_MAX_LEN: usize = 20;
 41
 42pub struct TerminalInlineAssistant {
 43    next_assist_id: TerminalInlineAssistId,
 44    assists: HashMap<TerminalInlineAssistId, TerminalInlineAssist>,
 45    prompt_history: VecDeque<String>,
 46    telemetry: Option<Arc<Telemetry>>,
 47    fs: Arc<dyn Fs>,
 48    prompt_builder: Arc<PromptBuilder>,
 49}
 50
 51impl Global for TerminalInlineAssistant {}
 52
 53impl TerminalInlineAssistant {
 54    pub fn new(
 55        fs: Arc<dyn Fs>,
 56        prompt_builder: Arc<PromptBuilder>,
 57        telemetry: Arc<Telemetry>,
 58    ) -> Self {
 59        Self {
 60            next_assist_id: TerminalInlineAssistId::default(),
 61            assists: HashMap::default(),
 62            prompt_history: VecDeque::default(),
 63            telemetry: Some(telemetry),
 64            fs,
 65            prompt_builder,
 66        }
 67    }
 68
 69    pub fn assist(
 70        &mut self,
 71        terminal_view: &Entity<TerminalView>,
 72        workspace: WeakEntity<Workspace>,
 73        project: WeakEntity<Project>,
 74        prompt_store: Option<Entity<PromptStore>>,
 75        thread_store: Option<WeakEntity<ThreadStore>>,
 76        text_thread_store: Option<WeakEntity<TextThreadStore>>,
 77        initial_prompt: Option<String>,
 78        window: &mut Window,
 79        cx: &mut App,
 80    ) {
 81        let terminal = terminal_view.read(cx).terminal().clone();
 82        let assist_id = self.next_assist_id.post_inc();
 83        let prompt_buffer = cx.new(|cx| {
 84            MultiBuffer::singleton(
 85                cx.new(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx)),
 86                cx,
 87            )
 88        });
 89        let context_store = cx.new(|_cx| ContextStore::new(project, thread_store.clone()));
 90        let codegen = cx.new(|_| TerminalCodegen::new(terminal, self.telemetry.clone()));
 91
 92        let prompt_editor = cx.new(|cx| {
 93            PromptEditor::new_terminal(
 94                assist_id,
 95                self.prompt_history.clone(),
 96                prompt_buffer.clone(),
 97                codegen,
 98                self.fs.clone(),
 99                context_store.clone(),
100                workspace.clone(),
101                thread_store.clone(),
102                text_thread_store.clone(),
103                window,
104                cx,
105            )
106        });
107        let prompt_editor_render = prompt_editor.clone();
108        let block = terminal_view::BlockProperties {
109            height: 4,
110            render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
111        };
112        terminal_view.update(cx, |terminal_view, cx| {
113            terminal_view.set_block_below_cursor(block, window, cx);
114        });
115
116        let terminal_assistant = TerminalInlineAssist::new(
117            assist_id,
118            terminal_view,
119            prompt_editor,
120            workspace.clone(),
121            context_store,
122            prompt_store,
123            window,
124            cx,
125        );
126
127        self.assists.insert(assist_id, terminal_assistant);
128
129        self.focus_assist(assist_id, window, cx);
130    }
131
132    fn focus_assist(
133        &mut self,
134        assist_id: TerminalInlineAssistId,
135        window: &mut Window,
136        cx: &mut App,
137    ) {
138        let assist = &self.assists[&assist_id];
139        if let Some(prompt_editor) = assist.prompt_editor.as_ref() {
140            prompt_editor.update(cx, |this, cx| {
141                this.editor.update(cx, |editor, cx| {
142                    window.focus(&editor.focus_handle(cx));
143                    editor.select_all(&SelectAll, window, cx);
144                });
145            });
146        }
147    }
148
149    fn handle_prompt_editor_event(
150        &mut self,
151        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
152        event: &PromptEditorEvent,
153        window: &mut Window,
154        cx: &mut App,
155    ) {
156        let assist_id = prompt_editor.read(cx).id();
157        match event {
158            PromptEditorEvent::StartRequested => {
159                self.start_assist(assist_id, cx);
160            }
161            PromptEditorEvent::StopRequested => {
162                self.stop_assist(assist_id, cx);
163            }
164            PromptEditorEvent::ConfirmRequested { execute } => {
165                self.finish_assist(assist_id, false, *execute, window, cx);
166            }
167            PromptEditorEvent::CancelRequested => {
168                self.finish_assist(assist_id, true, false, window, cx);
169            }
170            PromptEditorEvent::DismissRequested => {
171                self.dismiss_assist(assist_id, window, cx);
172            }
173            PromptEditorEvent::Resized { height_in_lines } => {
174                self.insert_prompt_editor_into_terminal(assist_id, *height_in_lines, window, cx);
175            }
176        }
177    }
178
179    fn start_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
180        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
181            assist
182        } else {
183            return;
184        };
185
186        let Some(user_prompt) = assist
187            .prompt_editor
188            .as_ref()
189            .map(|editor| editor.read(cx).prompt(cx))
190        else {
191            return;
192        };
193
194        self.prompt_history.retain(|prompt| *prompt != user_prompt);
195        self.prompt_history.push_back(user_prompt);
196        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
197            self.prompt_history.pop_front();
198        }
199
200        assist
201            .terminal
202            .update(cx, |terminal, cx| {
203                terminal
204                    .terminal()
205                    .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.as_bytes()));
206            })
207            .log_err();
208
209        let codegen = assist.codegen.clone();
210        let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
211            return;
212        };
213
214        codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
215    }
216
217    fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
218        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
219            assist
220        } else {
221            return;
222        };
223
224        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
225    }
226
227    fn request_for_inline_assist(
228        &self,
229        assist_id: TerminalInlineAssistId,
230        cx: &mut App,
231    ) -> Result<Task<LanguageModelRequest>> {
232        let assist = self.assists.get(&assist_id).context("invalid assist")?;
233
234        let shell = std::env::var("SHELL").ok();
235        let (latest_output, working_directory) = assist
236            .terminal
237            .update(cx, |terminal, cx| {
238                let terminal = terminal.entity().read(cx);
239                let latest_output = terminal.last_n_non_empty_lines(DEFAULT_CONTEXT_LINES);
240                let working_directory = terminal
241                    .working_directory()
242                    .map(|path| path.to_string_lossy().to_string());
243                (latest_output, working_directory)
244            })
245            .ok()
246            .unwrap_or_default();
247
248        let prompt = self.prompt_builder.generate_terminal_assistant_prompt(
249            &assist
250                .prompt_editor
251                .clone()
252                .context("invalid assist")?
253                .read(cx)
254                .prompt(cx),
255            shell.as_deref(),
256            working_directory.as_deref(),
257            &latest_output,
258        )?;
259
260        let contexts = assist
261            .context_store
262            .read(cx)
263            .context()
264            .cloned()
265            .collect::<Vec<_>>();
266        let context_load_task = assist.workspace.update(cx, |workspace, cx| {
267            let project = workspace.project();
268            load_context(contexts, project, &assist.prompt_store, cx)
269        })?;
270
271        let ConfiguredModel { model, .. } = LanguageModelRegistry::read_global(cx)
272            .inline_assistant_model()
273            .context("No inline assistant model")?;
274
275        let temperature = AgentSettings::temperature_for_model(&model, cx);
276
277        Ok(cx.background_spawn(async move {
278            let mut request_message = LanguageModelRequestMessage {
279                role: Role::User,
280                content: vec![],
281                cache: false,
282            };
283
284            context_load_task
285                .await
286                .loaded_context
287                .add_to_request_message(&mut request_message);
288
289            request_message.content.push(prompt.into());
290
291            LanguageModelRequest {
292                thread_id: None,
293                prompt_id: None,
294                mode: None,
295                intent: Some(CompletionIntent::TerminalInlineAssist),
296                messages: vec![request_message],
297                tools: Vec::new(),
298                tool_choice: None,
299                stop: Vec::new(),
300                temperature,
301            }
302        }))
303    }
304
305    fn finish_assist(
306        &mut self,
307        assist_id: TerminalInlineAssistId,
308        undo: bool,
309        execute: bool,
310        window: &mut Window,
311        cx: &mut App,
312    ) {
313        self.dismiss_assist(assist_id, window, cx);
314
315        if let Some(assist) = self.assists.remove(&assist_id) {
316            assist
317                .terminal
318                .update(cx, |this, cx| {
319                    this.clear_block_below_cursor(cx);
320                    this.focus_handle(cx).focus(window);
321                })
322                .log_err();
323
324            if let Some(ConfiguredModel { model, .. }) =
325                LanguageModelRegistry::read_global(cx).inline_assistant_model()
326            {
327                let codegen = assist.codegen.read(cx);
328                let executor = cx.background_executor().clone();
329                report_assistant_event(
330                    AssistantEventData {
331                        conversation_id: None,
332                        kind: AssistantKind::InlineTerminal,
333                        message_id: codegen.message_id.clone(),
334                        phase: if undo {
335                            AssistantPhase::Rejected
336                        } else {
337                            AssistantPhase::Accepted
338                        },
339                        model: model.telemetry_id(),
340                        model_provider: model.provider_id().to_string(),
341                        response_latency: None,
342                        error_message: None,
343                        language_name: None,
344                    },
345                    codegen.telemetry.clone(),
346                    cx.http_client(),
347                    model.api_key(cx),
348                    &executor,
349                );
350            }
351
352            assist.codegen.update(cx, |codegen, cx| {
353                if undo {
354                    codegen.undo(cx);
355                } else if execute {
356                    codegen.complete(cx);
357                }
358            });
359        }
360    }
361
362    fn dismiss_assist(
363        &mut self,
364        assist_id: TerminalInlineAssistId,
365        window: &mut Window,
366        cx: &mut App,
367    ) -> bool {
368        let Some(assist) = self.assists.get_mut(&assist_id) else {
369            return false;
370        };
371        if assist.prompt_editor.is_none() {
372            return false;
373        }
374        assist.prompt_editor = None;
375        assist
376            .terminal
377            .update(cx, |this, cx| {
378                this.clear_block_below_cursor(cx);
379                this.focus_handle(cx).focus(window);
380            })
381            .is_ok()
382    }
383
384    fn insert_prompt_editor_into_terminal(
385        &mut self,
386        assist_id: TerminalInlineAssistId,
387        height: u8,
388        window: &mut Window,
389        cx: &mut App,
390    ) {
391        if let Some(assist) = self.assists.get_mut(&assist_id) {
392            if let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned() {
393                assist
394                    .terminal
395                    .update(cx, |terminal, cx| {
396                        terminal.clear_block_below_cursor(cx);
397                        let block = terminal_view::BlockProperties {
398                            height,
399                            render: Box::new(move |_| prompt_editor.clone().into_any_element()),
400                        };
401                        terminal.set_block_below_cursor(block, window, cx);
402                    })
403                    .log_err();
404            }
405        }
406    }
407}
408
409struct TerminalInlineAssist {
410    terminal: WeakEntity<TerminalView>,
411    prompt_editor: Option<Entity<PromptEditor<TerminalCodegen>>>,
412    codegen: Entity<TerminalCodegen>,
413    workspace: WeakEntity<Workspace>,
414    context_store: Entity<ContextStore>,
415    prompt_store: Option<Entity<PromptStore>>,
416    _subscriptions: Vec<Subscription>,
417}
418
419impl TerminalInlineAssist {
420    pub fn new(
421        assist_id: TerminalInlineAssistId,
422        terminal: &Entity<TerminalView>,
423        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
424        workspace: WeakEntity<Workspace>,
425        context_store: Entity<ContextStore>,
426        prompt_store: Option<Entity<PromptStore>>,
427        window: &mut Window,
428        cx: &mut App,
429    ) -> Self {
430        let codegen = prompt_editor.read(cx).codegen().clone();
431        Self {
432            terminal: terminal.downgrade(),
433            prompt_editor: Some(prompt_editor.clone()),
434            codegen: codegen.clone(),
435            workspace: workspace.clone(),
436            context_store,
437            prompt_store,
438            _subscriptions: vec![
439                window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
440                    TerminalInlineAssistant::update_global(cx, |this, cx| {
441                        this.handle_prompt_editor_event(prompt_editor, event, window, cx)
442                    })
443                }),
444                window.subscribe(&codegen, cx, move |codegen, event, window, cx| {
445                    TerminalInlineAssistant::update_global(cx, |this, cx| match event {
446                        CodegenEvent::Finished => {
447                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
448                                assist
449                            } else {
450                                return;
451                            };
452
453                            if let CodegenStatus::Error(error) = &codegen.read(cx).status {
454                                if assist.prompt_editor.is_none() {
455                                    if let Some(workspace) = assist.workspace.upgrade() {
456                                        let error =
457                                            format!("Terminal inline assistant error: {}", error);
458                                        workspace.update(cx, |workspace, cx| {
459                                            struct InlineAssistantError;
460
461                                            let id =
462                                                NotificationId::composite::<InlineAssistantError>(
463                                                    assist_id.0,
464                                                );
465
466                                            workspace.show_toast(Toast::new(id, error), cx);
467                                        })
468                                    }
469                                }
470                            }
471
472                            if assist.prompt_editor.is_none() {
473                                this.finish_assist(assist_id, false, false, window, cx);
474                            }
475                        }
476                    })
477                }),
478            ],
479        }
480    }
481}