terminal_codegen.rs

  1use crate::inline_prompt_editor::CodegenStatus;
  2use client::telemetry::Telemetry;
  3use futures::{SinkExt, StreamExt, channel::mpsc};
  4use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Task};
  5use language_model::{
  6    ConfiguredModel, LanguageModelRegistry, LanguageModelRequest, report_assistant_event,
  7};
  8use std::{sync::Arc, time::Instant};
  9use telemetry_events::{AssistantEventData, AssistantKind, AssistantPhase};
 10use terminal::Terminal;
 11
 12pub struct TerminalCodegen {
 13    pub status: CodegenStatus,
 14    pub telemetry: Option<Arc<Telemetry>>,
 15    terminal: Entity<Terminal>,
 16    generation: Task<()>,
 17    pub message_id: Option<String>,
 18    transaction: Option<TerminalTransaction>,
 19}
 20
 21impl EventEmitter<CodegenEvent> for TerminalCodegen {}
 22
 23impl TerminalCodegen {
 24    pub fn new(terminal: Entity<Terminal>, telemetry: Option<Arc<Telemetry>>) -> Self {
 25        Self {
 26            terminal,
 27            telemetry,
 28            status: CodegenStatus::Idle,
 29            generation: Task::ready(()),
 30            message_id: None,
 31            transaction: None,
 32        }
 33    }
 34
 35    pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut Context<Self>) {
 36        let Some(ConfiguredModel { model, .. }) =
 37            LanguageModelRegistry::read_global(cx).inline_assistant_model()
 38        else {
 39            return;
 40        };
 41
 42        let model_api_key = model.api_key(cx);
 43        let http_client = cx.http_client();
 44        let telemetry = self.telemetry.clone();
 45        self.status = CodegenStatus::Pending;
 46        self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
 47        self.generation = cx.spawn(async move |this, cx| {
 48            let model_telemetry_id = model.telemetry_id();
 49            let model_provider_id = model.provider_id();
 50            let response = model.stream_completion_text(prompt, &cx).await;
 51            let generate = async {
 52                let message_id = response
 53                    .as_ref()
 54                    .ok()
 55                    .and_then(|response| response.message_id.clone());
 56
 57                let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 58
 59                let task = cx.background_spawn({
 60                    let message_id = message_id.clone();
 61                    let executor = cx.background_executor().clone();
 62                    async move {
 63                        let mut response_latency = None;
 64                        let request_start = Instant::now();
 65                        let task = async {
 66                            let mut chunks = response?.stream;
 67                            while let Some(chunk) = chunks.next().await {
 68                                if response_latency.is_none() {
 69                                    response_latency = Some(request_start.elapsed());
 70                                }
 71                                let chunk = chunk?;
 72                                hunks_tx.send(chunk).await?;
 73                            }
 74
 75                            anyhow::Ok(())
 76                        };
 77
 78                        let result = task.await;
 79
 80                        let error_message = result.as_ref().err().map(|error| error.to_string());
 81                        report_assistant_event(
 82                            AssistantEventData {
 83                                conversation_id: None,
 84                                kind: AssistantKind::InlineTerminal,
 85                                message_id,
 86                                phase: AssistantPhase::Response,
 87                                model: model_telemetry_id,
 88                                model_provider: model_provider_id.to_string(),
 89                                response_latency,
 90                                error_message,
 91                                language_name: None,
 92                            },
 93                            telemetry,
 94                            http_client,
 95                            model_api_key,
 96                            &executor,
 97                        );
 98
 99                        result?;
100                        anyhow::Ok(())
101                    }
102                });
103
104                this.update(cx, |this, _| {
105                    this.message_id = message_id;
106                })?;
107
108                while let Some(hunk) = hunks_rx.next().await {
109                    this.update(cx, |this, cx| {
110                        if let Some(transaction) = &mut this.transaction {
111                            transaction.push(hunk, cx);
112                            cx.notify();
113                        }
114                    })?;
115                }
116
117                task.await?;
118                anyhow::Ok(())
119            };
120
121            let result = generate.await;
122
123            this.update(cx, |this, cx| {
124                if let Err(error) = result {
125                    this.status = CodegenStatus::Error(error);
126                } else {
127                    this.status = CodegenStatus::Done;
128                }
129                cx.emit(CodegenEvent::Finished);
130                cx.notify();
131            })
132            .ok();
133        });
134        cx.notify();
135    }
136
137    pub fn stop(&mut self, cx: &mut Context<Self>) {
138        self.status = CodegenStatus::Done;
139        self.generation = Task::ready(());
140        cx.emit(CodegenEvent::Finished);
141        cx.notify();
142    }
143
144    pub fn complete(&mut self, cx: &mut Context<Self>) {
145        if let Some(transaction) = self.transaction.take() {
146            transaction.complete(cx);
147        }
148    }
149
150    pub fn undo(&mut self, cx: &mut Context<Self>) {
151        if let Some(transaction) = self.transaction.take() {
152            transaction.undo(cx);
153        }
154    }
155}
156
/// Events emitted by `TerminalCodegen`.
#[derive(Debug, Clone, Copy)]
pub enum CodegenEvent {
    /// The generation has ended, whether it succeeded, failed, or was stopped.
    Finished,
}
161
/// Input sent to the terminal to submit what has been typed (ASCII CR, `0x0D`).
const CARRIAGE_RETURN: &str = "\x0d";

/// Input sent to the terminal to discard what has been typed.
///
/// `0x15` is ASCII NAK, i.e. Ctrl-U.
/// NOTE(review): presumably handled by the shell as "kill line" — confirm for
/// the shells that need to be supported.
#[cfg(not(target_os = "windows"))]
pub const CLEAR_INPUT: &str = "\x15";

/// Input sent to the terminal to discard what has been typed (Windows).
///
/// `0x03` is ASCII ETX, i.e. Ctrl-C.
#[cfg(target_os = "windows")]
pub const CLEAR_INPUT: &str = "\x03";
167
168struct TerminalTransaction {
169    terminal: Entity<Terminal>,
170}
171
172impl TerminalTransaction {
173    pub fn start(terminal: Entity<Terminal>) -> Self {
174        Self { terminal }
175    }
176
177    pub fn push(&mut self, hunk: String, cx: &mut App) {
178        // Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
179        let input = Self::sanitize_input(hunk);
180        self.terminal
181            .update(cx, |terminal, _| terminal.input(input));
182    }
183
184    pub fn undo(&self, cx: &mut App) {
185        self.terminal
186            .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
187    }
188
189    pub fn complete(&self, cx: &mut App) {
190        self.terminal.update(cx, |terminal, _| {
191            terminal.input(CARRIAGE_RETURN.to_string())
192        });
193    }
194
195    fn sanitize_input(input: String) -> String {
196        input.replace(['\r', '\n'], "")
197    }
198}