terminal_codegen.rs

  1use crate::inline_prompt_editor::CodegenStatus;
  2use futures::{SinkExt, StreamExt, channel::mpsc};
  3use gpui::{App, AppContext as _, Context, Entity, EventEmitter, Task};
  4use language_model::{ConfiguredModel, LanguageModelRegistry, LanguageModelRequest};
  5use language_models::provider::anthropic::telemetry::{
  6    AnthropicCompletionType, AnthropicEventData, AnthropicEventReporter, AnthropicEventType,
  7};
  8use std::time::Instant;
  9use terminal::Terminal;
 10use uuid::Uuid;
 11
 12pub struct TerminalCodegen {
 13    pub status: CodegenStatus,
 14    terminal: Entity<Terminal>,
 15    generation: Task<()>,
 16    pub message_id: Option<String>,
 17    transaction: Option<TerminalTransaction>,
 18    session_id: Uuid,
 19}
 20
 21impl EventEmitter<CodegenEvent> for TerminalCodegen {}
 22
 23impl TerminalCodegen {
 24    pub fn new(terminal: Entity<Terminal>, session_id: Uuid) -> Self {
 25        Self {
 26            terminal,
 27            status: CodegenStatus::Idle,
 28            generation: Task::ready(()),
 29            message_id: None,
 30            transaction: None,
 31            session_id,
 32        }
 33    }
 34
 35    pub fn session_id(&self) -> Uuid {
 36        self.session_id
 37    }
 38
 39    pub fn start(&mut self, prompt_task: Task<LanguageModelRequest>, cx: &mut Context<Self>) {
 40        let Some(ConfiguredModel { model, .. }) =
 41            LanguageModelRegistry::read_global(cx).inline_assistant_model()
 42        else {
 43            return;
 44        };
 45
 46        let anthropic_reporter = AnthropicEventReporter::new(&model, cx);
 47        let session_id = self.session_id;
 48        let model_telemetry_id = model.telemetry_id();
 49        let model_provider_id = model.provider_id().to_string();
 50
 51        self.status = CodegenStatus::Pending;
 52        self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
 53        self.generation = cx.spawn(async move |this, cx| {
 54            let prompt = prompt_task.await;
 55            let response = model.stream_completion_text(prompt, cx).await;
 56            let generate = async {
 57                let message_id = response
 58                    .as_ref()
 59                    .ok()
 60                    .and_then(|response| response.message_id.clone());
 61
 62                let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
 63
 64                let task = cx.background_spawn({
 65                    let message_id = message_id.clone();
 66                    let anthropic_reporter = anthropic_reporter.clone();
 67                    async move {
 68                        let mut response_latency = None;
 69                        let request_start = Instant::now();
 70                        let task = async {
 71                            let mut chunks = response?.stream;
 72                            while let Some(chunk) = chunks.next().await {
 73                                if response_latency.is_none() {
 74                                    response_latency = Some(request_start.elapsed());
 75                                }
 76                                let chunk = chunk?;
 77                                hunks_tx.send(chunk).await?;
 78                            }
 79
 80                            anyhow::Ok(())
 81                        };
 82
 83                        let result = task.await;
 84
 85                        let error_message = result.as_ref().err().map(|error| error.to_string());
 86
 87                        telemetry::event!(
 88                            "Assistant Responded",
 89                            session_id = session_id.to_string(),
 90                            kind = "inline_terminal",
 91                            phase = "response",
 92                            model = model_telemetry_id,
 93                            model_provider = model_provider_id,
 94                            language_name = Option::<&str>::None,
 95                            message_id = message_id,
 96                            response_latency = response_latency,
 97                            error_message = error_message,
 98                        );
 99
100                        anthropic_reporter.report(AnthropicEventData {
101                            completion_type: AnthropicCompletionType::Terminal,
102                            event: AnthropicEventType::Response,
103                            language_name: None,
104                            message_id,
105                        });
106
107                        result?;
108                        anyhow::Ok(())
109                    }
110                });
111
112                this.update(cx, |this, _| {
113                    this.message_id = message_id;
114                })?;
115
116                while let Some(hunk) = hunks_rx.next().await {
117                    this.update(cx, |this, cx| {
118                        if let Some(transaction) = &mut this.transaction {
119                            transaction.push(hunk, cx);
120                            cx.notify();
121                        }
122                    })?;
123                }
124
125                task.await?;
126                anyhow::Ok(())
127            };
128
129            let result = generate.await;
130
131            this.update(cx, |this, cx| {
132                if let Err(error) = result {
133                    this.status = CodegenStatus::Error(error);
134                } else {
135                    this.status = CodegenStatus::Done;
136                }
137                cx.emit(CodegenEvent::Finished);
138                cx.notify();
139            })
140            .ok();
141        });
142        cx.notify();
143    }
144
145    pub fn completion(&self) -> Option<String> {
146        self.transaction
147            .as_ref()
148            .map(|transaction| transaction.completion.clone())
149    }
150
151    pub fn stop(&mut self, cx: &mut Context<Self>) {
152        self.status = CodegenStatus::Done;
153        self.generation = Task::ready(());
154        cx.emit(CodegenEvent::Finished);
155        cx.notify();
156    }
157
158    pub fn complete(&mut self, cx: &mut Context<Self>) {
159        if let Some(transaction) = self.transaction.take() {
160            transaction.complete(cx);
161        }
162    }
163
164    pub fn undo(&mut self, cx: &mut Context<Self>) {
165        if let Some(transaction) = self.transaction.take() {
166            transaction.undo(cx);
167        }
168    }
169}
170
/// Events emitted by `TerminalCodegen`.
///
/// Derives `PartialEq`/`Eq` so observers and tests can compare events directly.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum CodegenEvent {
    /// A generation ended.
    ///
    /// Emitted when the streamed completion finishes or fails, and when `stop` is called.
    Finished,
}
175
/// Control byte sent to discard the terminal's current input line.
///
/// This is Ctrl-U (`unix-line-discard` in readline-style shells).
#[cfg(not(target_os = "windows"))]
pub const CLEAR_INPUT: &str = "\u{15}";
/// Control byte sent to discard the terminal's current input line.
///
/// Windows uses Ctrl-C instead of Ctrl-U; presumably Windows shells don't bind Ctrl-U to
/// line discard (confirm).
#[cfg(target_os = "windows")]
pub const CLEAR_INPUT: &str = "\u{3}";
/// Carriage return, sent to submit the current input line.
const CARRIAGE_RETURN: &str = "\r";
181
182struct TerminalTransaction {
183    completion: String,
184    terminal: Entity<Terminal>,
185}
186
187impl TerminalTransaction {
188    pub fn start(terminal: Entity<Terminal>) -> Self {
189        Self {
190            completion: String::new(),
191            terminal,
192        }
193    }
194
195    pub fn push(&mut self, hunk: String, cx: &mut App) {
196        // Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
197        let input = Self::sanitize_input(hunk);
198        self.completion.push_str(&input);
199        self.terminal
200            .update(cx, |terminal, _| terminal.input(input.into_bytes()));
201    }
202
203    pub fn undo(self, cx: &mut App) {
204        self.terminal
205            .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.as_bytes()));
206    }
207
208    pub fn complete(self, cx: &mut App) {
209        self.terminal
210            .update(cx, |terminal, _| terminal.input(CARRIAGE_RETURN.as_bytes()));
211    }
212
213    fn sanitize_input(mut input: String) -> String {
214        input.retain(|c| c != '\r' && c != '\n');
215        input
216    }
217}