terminal_codegen.rs

  1use crate::inline_prompt_editor::CodegenStatus;
  2use client::telemetry::Telemetry;
  3use futures::{channel::mpsc, SinkExt, StreamExt};
  4use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Task};
  5use language_model::{LanguageModelRegistry, LanguageModelRequest};
  6use language_models::report_assistant_event;
  7use std::{sync::Arc, time::Instant};
  8use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
  9use terminal::Terminal;
 10
 11pub struct TerminalCodegen {
 12    pub status: CodegenStatus,
 13    pub telemetry: Option<Arc<Telemetry>>,
 14    terminal: Entity<Terminal>,
 15    generation: Task<()>,
 16    pub message_id: Option<String>,
 17    transaction: Option<TerminalTransaction>,
 18}
 19
 20impl EventEmitter<CodegenEvent> for TerminalCodegen {}
 21
 22impl TerminalCodegen {
 23    pub fn new(terminal: Entity<Terminal>, telemetry: Option<Arc<Telemetry>>) -> Self {
 24        Self {
 25            terminal,
 26            telemetry,
 27            status: CodegenStatus::Idle,
 28            generation: Task::ready(()),
 29            message_id: None,
 30            transaction: None,
 31        }
 32    }
 33
 34    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
 35        let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
 36            return;
 37        };
 38
 39        let model_api_key = model.api_key(cx);
 40        let http_client = cx.http_client();
 41        let telemetry = self.telemetry.clone();
 42        self.status = CodegenStatus::Pending;
 43        self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
 44        self.generation = cx.spawn(|this, mut cx| async move {
 45            let model_telemetry_id = model.telemetry_id();
 46            let model_provider_id = model.provider_id();
 47            let response = model.stream_completion_text(prompt, &cx).await;
 48            let generate = async {
 49                let message_id = response
 50                    .as_ref()
 51                    .ok()
 52                    .and_then(|response| response.message_id.clone());
 53
 54                let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 55
 56                let task = cx.background_spawn({
 57                    let message_id = message_id.clone();
 58                    let executor = cx.background_executor().clone();
 59                    async move {
 60                        let mut response_latency = None;
 61                        let request_start = Instant::now();
 62                        let task = async {
 63                            let mut chunks = response?.stream;
 64                            while let Some(chunk) = chunks.next().await {
 65                                if response_latency.is_none() {
 66                                    response_latency = Some(request_start.elapsed());
 67                                }
 68                                let chunk = chunk?;
 69                                hunks_tx.send(chunk).await?;
 70                            }
 71
 72                            anyhow::Ok(())
 73                        };
 74
 75                        let result = task.await;
 76
 77                        let error_message = result.as_ref().err().map(|error| error.to_string());
 78                        report_assistant_event(
 79                            AssistantEvent {
 80                                conversation_id: None,
 81                                kind: AssistantKind::InlineTerminal,
 82                                message_id,
 83                                phase: AssistantPhase::Response,
 84                                model: model_telemetry_id,
 85                                model_provider: model_provider_id.to_string(),
 86                                response_latency,
 87                                error_message,
 88                                language_name: None,
 89                            },
 90                            telemetry,
 91                            http_client,
 92                            model_api_key,
 93                            &executor,
 94                        );
 95
 96                        result?;
 97                        anyhow::Ok(())
 98                    }
 99                });
100
101                this.update(&mut cx, |this, _| {
102                    this.message_id = message_id;
103                })?;
104
105                while let Some(hunk) = hunks_rx.next().await {
106                    this.update(&mut cx, |this, cx| {
107                        if let Some(transaction) = &mut this.transaction {
108                            transaction.push(hunk, cx);
109                            cx.notify();
110                        }
111                    })?;
112                }
113
114                task.await?;
115                anyhow::Ok(())
116            };
117
118            let result = generate.await;
119
120            this.update(&mut cx, |this, cx| {
121                if let Err(error) = result {
122                    this.status = CodegenStatus::Error(error);
123                } else {
124                    this.status = CodegenStatus::Done;
125                }
126                cx.emit(CodegenEvent::Finished);
127                cx.notify();
128            })
129            .ok();
130        });
131        cx.notify();
132    }
133
134    pub fn stop(&mut self, cx: &mut Context<Self>) {
135        self.status = CodegenStatus::Done;
136        self.generation = Task::ready(());
137        cx.emit(CodegenEvent::Finished);
138        cx.notify();
139    }
140
141    pub fn complete(&mut self, cx: &mut Context<Self>) {
142        if let Some(transaction) = self.transaction.take() {
143            transaction.complete(cx);
144        }
145    }
146
147    pub fn undo(&mut self, cx: &mut Context<Self>) {
148        if let Some(transaction) = self.transaction.take() {
149            transaction.undo(cx);
150        }
151    }
152}
153
/// Events emitted by `TerminalCodegen`.
#[derive(Debug, Clone, Copy)]
pub enum CodegenEvent {
    /// A generation ended: the streaming task finished (successfully or with an error),
    /// or the generation was stopped.
    Finished,
}

/// Byte typed into the terminal to discard generated input: ASCII NAK (0x15), i.e. Ctrl-U,
/// which readline-style shells typically treat as "erase the line".
pub const CLEAR_INPUT: &str = "\u{15}";
/// Byte typed into the terminal to submit the current input line: ASCII CR (0x0d).
const CARRIAGE_RETURN: &str = "\r";

162struct TerminalTransaction {
163    terminal: Entity<Terminal>,
164}
165
166impl TerminalTransaction {
167    pub fn start(terminal: Entity<Terminal>) -> Self {
168        Self { terminal }
169    }
170
171    pub fn push(&mut self, hunk: String, cx: &mut App) {
172        // Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
173        let input = Self::sanitize_input(hunk);
174        self.terminal
175            .update(cx, |terminal, _| terminal.input(input));
176    }
177
178    pub fn undo(&self, cx: &mut App) {
179        self.terminal
180            .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
181    }
182
183    pub fn complete(&self, cx: &mut App) {
184        self.terminal.update(cx, |terminal, _| {
185            terminal.input(CARRIAGE_RETURN.to_string())
186        });
187    }
188
189    fn sanitize_input(input: String) -> String {
190        input.replace(['\r', '\n'], "")
191    }
192}