terminal_inline_assistant.rs

  1use crate::{
  2    ThreadHistory,
  3    context::load_context,
  4    inline_prompt_editor::{
  5        CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
  6    },
  7    terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen},
  8};
  9use agent::ThreadStore;
 10use agent_settings::AgentSettings;
 11use anyhow::{Context as _, Result};
 12
 13use collections::{HashMap, VecDeque};
 14use editor::{MultiBuffer, actions::SelectAll};
 15use fs::Fs;
 16use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
 17use language::Buffer;
 18use language_model::{
 19    CompletionIntent, ConfiguredModel, LanguageModelRegistry, LanguageModelRequest,
 20    LanguageModelRequestMessage, Role, report_anthropic_event,
 21};
 22use project::Project;
 23use prompt_store::{PromptBuilder, PromptStore};
 24use std::sync::Arc;
 25use terminal_view::TerminalView;
 26use ui::prelude::*;
 27use util::ResultExt;
 28use uuid::Uuid;
 29use workspace::{Toast, Workspace, notifications::NotificationId};
 30
 31pub fn init(fs: Arc<dyn Fs>, prompt_builder: Arc<PromptBuilder>, cx: &mut App) {
 32    cx.set_global(TerminalInlineAssistant::new(fs, prompt_builder));
 33}
 34
 35const DEFAULT_CONTEXT_LINES: usize = 50;
 36const PROMPT_HISTORY_MAX_LEN: usize = 20;
 37
 38pub struct TerminalInlineAssistant {
 39    next_assist_id: TerminalInlineAssistId,
 40    assists: HashMap<TerminalInlineAssistId, TerminalInlineAssist>,
 41    prompt_history: VecDeque<String>,
 42    fs: Arc<dyn Fs>,
 43    prompt_builder: Arc<PromptBuilder>,
 44}
 45
 46impl Global for TerminalInlineAssistant {}
 47
 48impl TerminalInlineAssistant {
 49    pub fn new(fs: Arc<dyn Fs>, prompt_builder: Arc<PromptBuilder>) -> Self {
 50        Self {
 51            next_assist_id: TerminalInlineAssistId::default(),
 52            assists: HashMap::default(),
 53            prompt_history: VecDeque::default(),
 54            fs,
 55            prompt_builder,
 56        }
 57    }
 58
 59    pub fn assist(
 60        &mut self,
 61        terminal_view: &Entity<TerminalView>,
 62        workspace: WeakEntity<Workspace>,
 63        project: WeakEntity<Project>,
 64        thread_store: Entity<ThreadStore>,
 65        prompt_store: Option<Entity<PromptStore>>,
 66        history: Option<WeakEntity<ThreadHistory>>,
 67        initial_prompt: Option<String>,
 68        window: &mut Window,
 69        cx: &mut App,
 70    ) {
 71        let terminal = terminal_view.read(cx).terminal().clone();
 72        let assist_id = self.next_assist_id.post_inc();
 73        let session_id = Uuid::new_v4();
 74        let prompt_buffer = cx.new(|cx| {
 75            MultiBuffer::singleton(
 76                cx.new(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx)),
 77                cx,
 78            )
 79        });
 80        let codegen = cx.new(|_| TerminalCodegen::new(terminal, session_id));
 81
 82        let prompt_editor = cx.new(|cx| {
 83            PromptEditor::new_terminal(
 84                assist_id,
 85                self.prompt_history.clone(),
 86                prompt_buffer.clone(),
 87                codegen,
 88                session_id,
 89                self.fs.clone(),
 90                thread_store.clone(),
 91                prompt_store.clone(),
 92                history,
 93                project.clone(),
 94                workspace.clone(),
 95                window,
 96                cx,
 97            )
 98        });
 99        let prompt_editor_render = prompt_editor.clone();
100        let block = terminal_view::BlockProperties {
101            height: 4,
102            render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
103        };
104        terminal_view.update(cx, |terminal_view, cx| {
105            terminal_view.set_block_below_cursor(block, window, cx);
106        });
107
108        let terminal_assistant = TerminalInlineAssist::new(
109            assist_id,
110            terminal_view,
111            prompt_editor,
112            workspace.clone(),
113            window,
114            cx,
115        );
116
117        self.assists.insert(assist_id, terminal_assistant);
118
119        self.focus_assist(assist_id, window, cx);
120    }
121
122    fn focus_assist(
123        &mut self,
124        assist_id: TerminalInlineAssistId,
125        window: &mut Window,
126        cx: &mut App,
127    ) {
128        let assist = &self.assists[&assist_id];
129        if let Some(prompt_editor) = assist.prompt_editor.as_ref() {
130            prompt_editor.update(cx, |this, cx| {
131                this.editor.update(cx, |editor, cx| {
132                    window.focus(&editor.focus_handle(cx), cx);
133                    editor.select_all(&SelectAll, window, cx);
134                });
135            });
136        }
137    }
138
139    fn handle_prompt_editor_event(
140        &mut self,
141        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
142        event: &PromptEditorEvent,
143        window: &mut Window,
144        cx: &mut App,
145    ) {
146        let assist_id = prompt_editor.read(cx).id();
147        match event {
148            PromptEditorEvent::StartRequested => {
149                self.start_assist(assist_id, cx);
150            }
151            PromptEditorEvent::StopRequested => {
152                self.stop_assist(assist_id, cx);
153            }
154            PromptEditorEvent::ConfirmRequested { execute } => {
155                self.finish_assist(assist_id, false, *execute, window, cx);
156            }
157            PromptEditorEvent::CancelRequested => {
158                self.finish_assist(assist_id, true, false, window, cx);
159            }
160            PromptEditorEvent::Resized { height_in_lines } => {
161                self.insert_prompt_editor_into_terminal(assist_id, *height_in_lines, window, cx);
162            }
163        }
164    }
165
166    fn start_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
167        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
168            assist
169        } else {
170            return;
171        };
172
173        let Some(user_prompt) = assist
174            .prompt_editor
175            .as_ref()
176            .map(|editor| editor.read(cx).prompt(cx))
177        else {
178            return;
179        };
180
181        self.prompt_history.retain(|prompt| *prompt != user_prompt);
182        self.prompt_history.push_back(user_prompt);
183        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
184            self.prompt_history.pop_front();
185        }
186
187        assist
188            .terminal
189            .update(cx, |terminal, cx| {
190                terminal
191                    .terminal()
192                    .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.as_bytes()));
193            })
194            .log_err();
195
196        let codegen = assist.codegen.clone();
197        let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
198            return;
199        };
200
201        codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
202    }
203
204    fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
205        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
206            assist
207        } else {
208            return;
209        };
210
211        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
212    }
213
214    fn request_for_inline_assist(
215        &self,
216        assist_id: TerminalInlineAssistId,
217        cx: &mut App,
218    ) -> Result<Task<LanguageModelRequest>> {
219        let ConfiguredModel { model, .. } = LanguageModelRegistry::read_global(cx)
220            .inline_assistant_model()
221            .context("No inline assistant model")?;
222
223        let assist = self.assists.get(&assist_id).context("invalid assist")?;
224
225        let shell = std::env::var("SHELL").ok();
226        let (latest_output, working_directory) = assist
227            .terminal
228            .update(cx, |terminal, cx| {
229                let terminal = terminal.entity().read(cx);
230                let latest_output = terminal.last_n_non_empty_lines(DEFAULT_CONTEXT_LINES);
231                let working_directory = terminal
232                    .working_directory()
233                    .map(|path| path.to_string_lossy().into_owned());
234                (latest_output, working_directory)
235            })
236            .ok()
237            .unwrap_or_default();
238
239        let prompt_editor = assist.prompt_editor.clone().context("invalid assist")?;
240
241        let prompt = self.prompt_builder.generate_terminal_assistant_prompt(
242            &prompt_editor.read(cx).prompt(cx),
243            shell.as_deref(),
244            working_directory.as_deref(),
245            &latest_output,
246        )?;
247
248        let temperature = AgentSettings::temperature_for_model(&model, cx);
249
250        let mention_set = prompt_editor.read(cx).mention_set().clone();
251        let load_context_task = load_context(&mention_set, cx);
252
253        Ok(cx.background_spawn(async move {
254            let mut request_message = LanguageModelRequestMessage {
255                role: Role::User,
256                content: vec![],
257                cache: false,
258                reasoning_details: None,
259            };
260
261            if let Some(context) = load_context_task.await {
262                context.add_to_request_message(&mut request_message);
263            }
264
265            request_message.content.push(prompt.into());
266
267            LanguageModelRequest {
268                thread_id: None,
269                prompt_id: None,
270                intent: Some(CompletionIntent::TerminalInlineAssist),
271                messages: vec![request_message],
272                tools: Vec::new(),
273                tool_choice: None,
274                stop: Vec::new(),
275                temperature,
276                thinking_allowed: false,
277                thinking_effort: None,
278                speed: None,
279            }
280        }))
281    }
282
283    fn finish_assist(
284        &mut self,
285        assist_id: TerminalInlineAssistId,
286        undo: bool,
287        execute: bool,
288        window: &mut Window,
289        cx: &mut App,
290    ) {
291        self.dismiss_assist(assist_id, window, cx);
292
293        if let Some(assist) = self.assists.remove(&assist_id) {
294            assist
295                .terminal
296                .update(cx, |this, cx| {
297                    this.clear_block_below_cursor(cx);
298                    this.focus_handle(cx).focus(window, cx);
299                })
300                .log_err();
301
302            if let Some(ConfiguredModel { model, .. }) =
303                LanguageModelRegistry::read_global(cx).inline_assistant_model()
304            {
305                let codegen = assist.codegen.read(cx);
306                let session_id = codegen.session_id();
307                let message_id = codegen.message_id.clone();
308                let model_telemetry_id = model.telemetry_id();
309                let model_provider_id = model.provider_id().to_string();
310
311                let (phase, event_type, anthropic_event_type) = if undo {
312                    (
313                        "rejected",
314                        "Assistant Response Rejected",
315                        language_model::AnthropicEventType::Reject,
316                    )
317                } else {
318                    (
319                        "accepted",
320                        "Assistant Response Accepted",
321                        language_model::AnthropicEventType::Accept,
322                    )
323                };
324
325                // Fire Zed telemetry
326                telemetry::event!(
327                    event_type,
328                    kind = "inline_terminal",
329                    phase = phase,
330                    model = model_telemetry_id,
331                    model_provider = model_provider_id,
332                    message_id = message_id,
333                    session_id = session_id,
334                );
335
336                report_anthropic_event(
337                    &model,
338                    language_model::AnthropicEventData {
339                        completion_type: language_model::AnthropicCompletionType::Terminal,
340                        event: anthropic_event_type,
341                        language_name: None,
342                        message_id,
343                    },
344                    cx,
345                );
346            }
347
348            assist.codegen.update(cx, |codegen, cx| {
349                if undo {
350                    codegen.undo(cx);
351                } else if execute {
352                    codegen.complete(cx);
353                }
354            });
355        }
356    }
357
358    fn dismiss_assist(
359        &mut self,
360        assist_id: TerminalInlineAssistId,
361        window: &mut Window,
362        cx: &mut App,
363    ) -> bool {
364        let Some(assist) = self.assists.get_mut(&assist_id) else {
365            return false;
366        };
367        if assist.prompt_editor.is_none() {
368            return false;
369        }
370        assist.prompt_editor = None;
371        assist
372            .terminal
373            .update(cx, |this, cx| {
374                this.clear_block_below_cursor(cx);
375                this.focus_handle(cx).focus(window, cx);
376            })
377            .is_ok()
378    }
379
380    fn insert_prompt_editor_into_terminal(
381        &mut self,
382        assist_id: TerminalInlineAssistId,
383        height: u8,
384        window: &mut Window,
385        cx: &mut App,
386    ) {
387        if let Some(assist) = self.assists.get_mut(&assist_id)
388            && let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned()
389        {
390            assist
391                .terminal
392                .update(cx, |terminal, cx| {
393                    terminal.clear_block_below_cursor(cx);
394                    let block = terminal_view::BlockProperties {
395                        height,
396                        render: Box::new(move |_| prompt_editor.clone().into_any_element()),
397                    };
398                    terminal.set_block_below_cursor(block, window, cx);
399                })
400                .log_err();
401        }
402    }
403}
404
405struct TerminalInlineAssist {
406    terminal: WeakEntity<TerminalView>,
407    prompt_editor: Option<Entity<PromptEditor<TerminalCodegen>>>,
408    codegen: Entity<TerminalCodegen>,
409    workspace: WeakEntity<Workspace>,
410    _subscriptions: Vec<Subscription>,
411}
412
413impl TerminalInlineAssist {
414    pub fn new(
415        assist_id: TerminalInlineAssistId,
416        terminal: &Entity<TerminalView>,
417        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
418        workspace: WeakEntity<Workspace>,
419        window: &mut Window,
420        cx: &mut App,
421    ) -> Self {
422        let codegen = prompt_editor.read(cx).codegen().clone();
423        Self {
424            terminal: terminal.downgrade(),
425            prompt_editor: Some(prompt_editor.clone()),
426            codegen: codegen.clone(),
427            workspace,
428            _subscriptions: vec![
429                window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
430                    TerminalInlineAssistant::update_global(cx, |this, cx| {
431                        this.handle_prompt_editor_event(prompt_editor, event, window, cx)
432                    })
433                }),
434                window.subscribe(&codegen, cx, move |codegen, event, window, cx| {
435                    TerminalInlineAssistant::update_global(cx, |this, cx| match event {
436                        CodegenEvent::Finished => {
437                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
438                                assist
439                            } else {
440                                return;
441                            };
442
443                            if let CodegenStatus::Error(error) = &codegen.read(cx).status
444                                && assist.prompt_editor.is_none()
445                                && let Some(workspace) = assist.workspace.upgrade()
446                            {
447                                let error = format!("Terminal inline assistant error: {}", error);
448                                workspace.update(cx, |workspace, cx| {
449                                    struct InlineAssistantError;
450
451                                    let id = NotificationId::composite::<InlineAssistantError>(
452                                        assist_id.0,
453                                    );
454
455                                    workspace.show_toast(Toast::new(id, error), cx);
456                                })
457                            }
458
459                            if assist.prompt_editor.is_none() {
460                                this.finish_assist(assist_id, false, false, window, cx);
461                            }
462                        }
463                    })
464                }),
465            ],
466        }
467    }
468}