terminal_inline_assistant.rs

  1use crate::{
  2    ThreadHistory,
  3    context::load_context,
  4    inline_prompt_editor::{
  5        CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
  6    },
  7    terminal_codegen::{CLEAR_INPUT, CodegenEvent, TerminalCodegen},
  8};
  9use agent::ThreadStore;
 10use agent_settings::AgentSettings;
 11use anyhow::{Context as _, Result};
 12
 13use collections::{HashMap, VecDeque};
 14use editor::{MultiBuffer, actions::SelectAll};
 15use fs::Fs;
 16use gpui::{App, Entity, Focusable, Global, Subscription, Task, UpdateGlobal, WeakEntity};
 17use language::Buffer;
 18use language_model::{
 19    CompletionIntent, ConfiguredModel, LanguageModelRegistry, LanguageModelRequest,
 20    LanguageModelRequestMessage, Role,
 21};
 22use language_models::provider::anthropic::telemetry::{
 23    AnthropicCompletionType, AnthropicEventData, AnthropicEventType, report_anthropic_event,
 24};
 25use project::Project;
 26use prompt_store::{PromptBuilder, PromptStore};
 27use std::sync::Arc;
 28use terminal_view::TerminalView;
 29use ui::prelude::*;
 30use util::ResultExt;
 31use uuid::Uuid;
 32use workspace::{Toast, Workspace, notifications::NotificationId};
 33
 34pub fn init(fs: Arc<dyn Fs>, prompt_builder: Arc<PromptBuilder>, cx: &mut App) {
 35    cx.set_global(TerminalInlineAssistant::new(fs, prompt_builder));
 36}
 37
/// Number of trailing non-empty terminal lines requested as context when
/// building the model prompt.
const DEFAULT_CONTEXT_LINES: usize = 50;
/// Upper bound on the number of prompts kept in the prompt history.
const PROMPT_HISTORY_MAX_LEN: usize = 20;
 40
/// Global state for inline assists running inside terminals.
///
/// Owns every in-flight `TerminalInlineAssist`, keyed by its id, together
/// with the prompt history that is handed to newly created prompt editors.
pub struct TerminalInlineAssistant {
    /// Id assigned to the next assist; advanced with `post_inc` in `assist`.
    next_assist_id: TerminalInlineAssistId,
    /// Assists that have been opened and not yet finished.
    assists: HashMap<TerminalInlineAssistId, TerminalInlineAssist>,
    /// Previously submitted prompts, oldest at the front, capped at
    /// `PROMPT_HISTORY_MAX_LEN` entries.
    prompt_history: VecDeque<String>,
    /// File system handle forwarded to each prompt editor.
    fs: Arc<dyn Fs>,
    /// Generates the terminal assistant prompt sent to the language model.
    prompt_builder: Arc<PromptBuilder>,
}

// Marker impl that lets the assistant be stored with `cx.set_global` and
// accessed through `update_global`.
impl Global for TerminalInlineAssistant {}
 50
 51impl TerminalInlineAssistant {
 52    pub fn new(fs: Arc<dyn Fs>, prompt_builder: Arc<PromptBuilder>) -> Self {
 53        Self {
 54            next_assist_id: TerminalInlineAssistId::default(),
 55            assists: HashMap::default(),
 56            prompt_history: VecDeque::default(),
 57            fs,
 58            prompt_builder,
 59        }
 60    }
 61
 62    pub fn assist(
 63        &mut self,
 64        terminal_view: &Entity<TerminalView>,
 65        workspace: WeakEntity<Workspace>,
 66        project: WeakEntity<Project>,
 67        thread_store: Entity<ThreadStore>,
 68        prompt_store: Option<Entity<PromptStore>>,
 69        history: Option<WeakEntity<ThreadHistory>>,
 70        initial_prompt: Option<String>,
 71        window: &mut Window,
 72        cx: &mut App,
 73    ) {
 74        let terminal = terminal_view.read(cx).terminal().clone();
 75        let assist_id = self.next_assist_id.post_inc();
 76        let session_id = Uuid::new_v4();
 77        let prompt_buffer = cx.new(|cx| {
 78            MultiBuffer::singleton(
 79                cx.new(|cx| Buffer::local(initial_prompt.unwrap_or_default(), cx)),
 80                cx,
 81            )
 82        });
 83        let codegen = cx.new(|_| TerminalCodegen::new(terminal, session_id));
 84
 85        let prompt_editor = cx.new(|cx| {
 86            PromptEditor::new_terminal(
 87                assist_id,
 88                self.prompt_history.clone(),
 89                prompt_buffer.clone(),
 90                codegen,
 91                session_id,
 92                self.fs.clone(),
 93                thread_store.clone(),
 94                prompt_store.clone(),
 95                history,
 96                project.clone(),
 97                workspace.clone(),
 98                window,
 99                cx,
100            )
101        });
102        let prompt_editor_render = prompt_editor.clone();
103        let block = terminal_view::BlockProperties {
104            height: 4,
105            render: Box::new(move |_| prompt_editor_render.clone().into_any_element()),
106        };
107        terminal_view.update(cx, |terminal_view, cx| {
108            terminal_view.set_block_below_cursor(block, window, cx);
109        });
110
111        let terminal_assistant = TerminalInlineAssist::new(
112            assist_id,
113            terminal_view,
114            prompt_editor,
115            workspace.clone(),
116            window,
117            cx,
118        );
119
120        self.assists.insert(assist_id, terminal_assistant);
121
122        self.focus_assist(assist_id, window, cx);
123    }
124
125    fn focus_assist(
126        &mut self,
127        assist_id: TerminalInlineAssistId,
128        window: &mut Window,
129        cx: &mut App,
130    ) {
131        let assist = &self.assists[&assist_id];
132        if let Some(prompt_editor) = assist.prompt_editor.as_ref() {
133            prompt_editor.update(cx, |this, cx| {
134                this.editor.update(cx, |editor, cx| {
135                    window.focus(&editor.focus_handle(cx), cx);
136                    editor.select_all(&SelectAll, window, cx);
137                });
138            });
139        }
140    }
141
142    fn handle_prompt_editor_event(
143        &mut self,
144        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
145        event: &PromptEditorEvent,
146        window: &mut Window,
147        cx: &mut App,
148    ) {
149        let assist_id = prompt_editor.read(cx).id();
150        match event {
151            PromptEditorEvent::StartRequested => {
152                self.start_assist(assist_id, cx);
153            }
154            PromptEditorEvent::StopRequested => {
155                self.stop_assist(assist_id, cx);
156            }
157            PromptEditorEvent::ConfirmRequested { execute } => {
158                self.finish_assist(assist_id, false, *execute, window, cx);
159            }
160            PromptEditorEvent::CancelRequested => {
161                self.finish_assist(assist_id, true, false, window, cx);
162            }
163            PromptEditorEvent::Resized { height_in_lines } => {
164                self.insert_prompt_editor_into_terminal(assist_id, *height_in_lines, window, cx);
165            }
166        }
167    }
168
169    fn start_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
170        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
171            assist
172        } else {
173            return;
174        };
175
176        let Some(user_prompt) = assist
177            .prompt_editor
178            .as_ref()
179            .map(|editor| editor.read(cx).prompt(cx))
180        else {
181            return;
182        };
183
184        self.prompt_history.retain(|prompt| *prompt != user_prompt);
185        self.prompt_history.push_back(user_prompt);
186        if self.prompt_history.len() > PROMPT_HISTORY_MAX_LEN {
187            self.prompt_history.pop_front();
188        }
189
190        assist
191            .terminal
192            .update(cx, |terminal, cx| {
193                terminal
194                    .terminal()
195                    .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.as_bytes()));
196            })
197            .log_err();
198
199        let codegen = assist.codegen.clone();
200        let Some(request_task) = self.request_for_inline_assist(assist_id, cx).log_err() else {
201            return;
202        };
203
204        codegen.update(cx, |codegen, cx| codegen.start(request_task, cx));
205    }
206
207    fn stop_assist(&mut self, assist_id: TerminalInlineAssistId, cx: &mut App) {
208        let assist = if let Some(assist) = self.assists.get_mut(&assist_id) {
209            assist
210        } else {
211            return;
212        };
213
214        assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
215    }
216
217    fn request_for_inline_assist(
218        &self,
219        assist_id: TerminalInlineAssistId,
220        cx: &mut App,
221    ) -> Result<Task<LanguageModelRequest>> {
222        let ConfiguredModel { model, .. } = LanguageModelRegistry::read_global(cx)
223            .inline_assistant_model()
224            .context("No inline assistant model")?;
225
226        let assist = self.assists.get(&assist_id).context("invalid assist")?;
227
228        let shell = std::env::var("SHELL").ok();
229        let (latest_output, working_directory) = assist
230            .terminal
231            .update(cx, |terminal, cx| {
232                let terminal = terminal.entity().read(cx);
233                let latest_output = terminal.last_n_non_empty_lines(DEFAULT_CONTEXT_LINES);
234                let working_directory = terminal
235                    .working_directory()
236                    .map(|path| path.to_string_lossy().into_owned());
237                (latest_output, working_directory)
238            })
239            .ok()
240            .unwrap_or_default();
241
242        let prompt_editor = assist.prompt_editor.clone().context("invalid assist")?;
243
244        let prompt = self.prompt_builder.generate_terminal_assistant_prompt(
245            &prompt_editor.read(cx).prompt(cx),
246            shell.as_deref(),
247            working_directory.as_deref(),
248            &latest_output,
249        )?;
250
251        let temperature = AgentSettings::temperature_for_model(&model, cx);
252
253        let mention_set = prompt_editor.read(cx).mention_set().clone();
254        let load_context_task = load_context(&mention_set, cx);
255
256        Ok(cx.background_spawn(async move {
257            let mut request_message = LanguageModelRequestMessage {
258                role: Role::User,
259                content: vec![],
260                cache: false,
261                reasoning_details: None,
262            };
263
264            if let Some(context) = load_context_task.await {
265                context.add_to_request_message(&mut request_message);
266            }
267
268            request_message.content.push(prompt.into());
269
270            LanguageModelRequest {
271                thread_id: None,
272                prompt_id: None,
273                intent: Some(CompletionIntent::TerminalInlineAssist),
274                messages: vec![request_message],
275                tools: Vec::new(),
276                tool_choice: None,
277                stop: Vec::new(),
278                temperature,
279                thinking_allowed: false,
280                thinking_effort: None,
281                speed: None,
282            }
283        }))
284    }
285
286    fn finish_assist(
287        &mut self,
288        assist_id: TerminalInlineAssistId,
289        undo: bool,
290        execute: bool,
291        window: &mut Window,
292        cx: &mut App,
293    ) {
294        self.dismiss_assist(assist_id, window, cx);
295
296        if let Some(assist) = self.assists.remove(&assist_id) {
297            assist
298                .terminal
299                .update(cx, |this, cx| {
300                    this.clear_block_below_cursor(cx);
301                    this.focus_handle(cx).focus(window, cx);
302                })
303                .log_err();
304
305            if let Some(ConfiguredModel { model, .. }) =
306                LanguageModelRegistry::read_global(cx).inline_assistant_model()
307            {
308                let codegen = assist.codegen.read(cx);
309                let session_id = codegen.session_id();
310                let message_id = codegen.message_id.clone();
311                let model_telemetry_id = model.telemetry_id();
312                let model_provider_id = model.provider_id().to_string();
313
314                let (phase, event_type, anthropic_event_type) = if undo {
315                    (
316                        "rejected",
317                        "Assistant Response Rejected",
318                        AnthropicEventType::Reject,
319                    )
320                } else {
321                    (
322                        "accepted",
323                        "Assistant Response Accepted",
324                        AnthropicEventType::Accept,
325                    )
326                };
327
328                // Fire Zed telemetry
329                telemetry::event!(
330                    event_type,
331                    kind = "inline_terminal",
332                    phase = phase,
333                    model = model_telemetry_id,
334                    model_provider = model_provider_id,
335                    message_id = message_id,
336                    session_id = session_id,
337                );
338
339                report_anthropic_event(
340                    &model,
341                    AnthropicEventData {
342                        completion_type: AnthropicCompletionType::Terminal,
343                        event: anthropic_event_type,
344                        language_name: None,
345                        message_id,
346                    },
347                    cx,
348                );
349            }
350
351            assist.codegen.update(cx, |codegen, cx| {
352                if undo {
353                    codegen.undo(cx);
354                } else if execute {
355                    codegen.complete(cx);
356                }
357            });
358        }
359    }
360
361    fn dismiss_assist(
362        &mut self,
363        assist_id: TerminalInlineAssistId,
364        window: &mut Window,
365        cx: &mut App,
366    ) -> bool {
367        let Some(assist) = self.assists.get_mut(&assist_id) else {
368            return false;
369        };
370        if assist.prompt_editor.is_none() {
371            return false;
372        }
373        assist.prompt_editor = None;
374        assist
375            .terminal
376            .update(cx, |this, cx| {
377                this.clear_block_below_cursor(cx);
378                this.focus_handle(cx).focus(window, cx);
379            })
380            .is_ok()
381    }
382
383    fn insert_prompt_editor_into_terminal(
384        &mut self,
385        assist_id: TerminalInlineAssistId,
386        height: u8,
387        window: &mut Window,
388        cx: &mut App,
389    ) {
390        if let Some(assist) = self.assists.get_mut(&assist_id)
391            && let Some(prompt_editor) = assist.prompt_editor.as_ref().cloned()
392        {
393            assist
394                .terminal
395                .update(cx, |terminal, cx| {
396                    terminal.clear_block_below_cursor(cx);
397                    let block = terminal_view::BlockProperties {
398                        height,
399                        render: Box::new(move |_| prompt_editor.clone().into_any_element()),
400                    };
401                    terminal.set_block_below_cursor(block, window, cx);
402                })
403                .log_err();
404        }
405    }
406}
407
/// State of a single inline assist attached to one terminal view.
struct TerminalInlineAssist {
    /// Weak handle to the terminal view hosting the assist.
    terminal: WeakEntity<TerminalView>,
    /// The prompt editor; set to `None` once the assist has been dismissed.
    prompt_editor: Option<Entity<PromptEditor<TerminalCodegen>>>,
    /// Codegen taken from the prompt editor when the assist is created.
    codegen: Entity<TerminalCodegen>,
    /// Workspace used to show an error toast when generation fails after the
    /// prompt editor was dismissed.
    workspace: WeakEntity<Workspace>,
    /// Subscriptions to prompt-editor and codegen events created in `new`,
    /// stored here for as long as the assist exists.
    _subscriptions: Vec<Subscription>,
}
415
416impl TerminalInlineAssist {
417    pub fn new(
418        assist_id: TerminalInlineAssistId,
419        terminal: &Entity<TerminalView>,
420        prompt_editor: Entity<PromptEditor<TerminalCodegen>>,
421        workspace: WeakEntity<Workspace>,
422        window: &mut Window,
423        cx: &mut App,
424    ) -> Self {
425        let codegen = prompt_editor.read(cx).codegen().clone();
426        Self {
427            terminal: terminal.downgrade(),
428            prompt_editor: Some(prompt_editor.clone()),
429            codegen: codegen.clone(),
430            workspace,
431            _subscriptions: vec![
432                window.subscribe(&prompt_editor, cx, |prompt_editor, event, window, cx| {
433                    TerminalInlineAssistant::update_global(cx, |this, cx| {
434                        this.handle_prompt_editor_event(prompt_editor, event, window, cx)
435                    })
436                }),
437                window.subscribe(&codegen, cx, move |codegen, event, window, cx| {
438                    TerminalInlineAssistant::update_global(cx, |this, cx| match event {
439                        CodegenEvent::Finished => {
440                            let assist = if let Some(assist) = this.assists.get(&assist_id) {
441                                assist
442                            } else {
443                                return;
444                            };
445
446                            if let CodegenStatus::Error(error) = &codegen.read(cx).status
447                                && assist.prompt_editor.is_none()
448                                && let Some(workspace) = assist.workspace.upgrade()
449                            {
450                                let error = format!("Terminal inline assistant error: {}", error);
451                                workspace.update(cx, |workspace, cx| {
452                                    struct InlineAssistantError;
453
454                                    let id = NotificationId::composite::<InlineAssistantError>(
455                                        assist_id.0,
456                                    );
457
458                                    workspace.show_toast(Toast::new(id, error), cx);
459                                })
460                            }
461
462                            if assist.prompt_editor.is_none() {
463                                this.finish_assist(assist_id, false, false, window, cx);
464                            }
465                        }
466                    })
467                }),
468            ],
469        }
470    }
471}