Detailed changes
@@ -1,6 +1,7 @@
mod active_thread;
mod assistant_panel;
mod assistant_settings;
+mod buffer_codegen;
mod context;
mod context_picker;
mod context_store;
@@ -10,6 +11,7 @@ mod inline_prompt_editor;
mod message_editor;
mod prompts;
mod streaming_diff;
+mod terminal_codegen;
mod terminal_inline_assistant;
mod thread;
mod thread_history;
@@ -0,0 +1,1475 @@
+use crate::context::attach_context_to_message;
+use crate::context_store::ContextStore;
+use crate::inline_prompt_editor::CodegenStatus;
+use crate::{
+ prompts::PromptBuilder,
+ streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff},
+};
+use anyhow::{Context as _, Result};
+use client::telemetry::Telemetry;
+use collections::HashSet;
+use editor::{Anchor, AnchorRangeExt, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint};
+use futures::{channel::mpsc, future::LocalBoxFuture, join, SinkExt, Stream, StreamExt};
+use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscription, Task};
+use language::{Buffer, IndentKind, Point, TransactionId};
+use language_model::{
+ LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+ LanguageModelTextStream, Role,
+};
+use language_models::report_assistant_event;
+use multi_buffer::MultiBufferRow;
+use parking_lot::Mutex;
+use rope::Rope;
+use smol::future::FutureExt;
+use std::{
+ cmp,
+ future::Future,
+ iter,
+ ops::{Range, RangeInclusive},
+ pin::Pin,
+ sync::Arc,
+ task::{self, Poll},
+ time::Instant,
+};
+use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
+
+pub struct BufferCodegen {
+ alternatives: Vec<Model<CodegenAlternative>>,
+ pub active_alternative: usize,
+ seen_alternatives: HashSet<usize>,
+ subscriptions: Vec<Subscription>,
+ buffer: Model<MultiBuffer>,
+ range: Range<Anchor>,
+ initial_transaction_id: Option<TransactionId>,
+ context_store: Model<ContextStore>,
+ telemetry: Arc<Telemetry>,
+ builder: Arc<PromptBuilder>,
+ pub is_insertion: bool,
+}
+
+impl BufferCodegen {
+ pub fn new(
+ buffer: Model<MultiBuffer>,
+ range: Range<Anchor>,
+ initial_transaction_id: Option<TransactionId>,
+ context_store: Model<ContextStore>,
+ telemetry: Arc<Telemetry>,
+ builder: Arc<PromptBuilder>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ let codegen = cx.new_model(|cx| {
+ CodegenAlternative::new(
+ buffer.clone(),
+ range.clone(),
+ false,
+ Some(context_store.clone()),
+ Some(telemetry.clone()),
+ builder.clone(),
+ cx,
+ )
+ });
+ let mut this = Self {
+ is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
+ alternatives: vec![codegen],
+ active_alternative: 0,
+ seen_alternatives: HashSet::default(),
+ subscriptions: Vec::new(),
+ buffer,
+ range,
+ initial_transaction_id,
+ context_store,
+ telemetry,
+ builder,
+ };
+ this.activate(0, cx);
+ this
+ }
+
+ fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
+ let codegen = self.active_alternative().clone();
+ self.subscriptions.clear();
+ self.subscriptions
+ .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
+ self.subscriptions
+ .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
+ }
+
+ pub fn active_alternative(&self) -> &Model<CodegenAlternative> {
+ &self.alternatives[self.active_alternative]
+ }
+
+ pub fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
+ &self.active_alternative().read(cx).status
+ }
+
+ pub fn alternative_count(&self, cx: &AppContext) -> usize {
+ LanguageModelRegistry::read_global(cx)
+ .inline_alternative_models()
+ .len()
+ + 1
+ }
+
+ pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
+ let next_active_ix = if self.active_alternative == 0 {
+ self.alternatives.len() - 1
+ } else {
+ self.active_alternative - 1
+ };
+ self.activate(next_active_ix, cx);
+ }
+
+ pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
+ let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
+ self.activate(next_active_ix, cx);
+ }
+
+ fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
+ self.active_alternative()
+ .update(cx, |codegen, cx| codegen.set_active(false, cx));
+ self.seen_alternatives.insert(index);
+ self.active_alternative = index;
+ self.active_alternative()
+ .update(cx, |codegen, cx| codegen.set_active(true, cx));
+ self.subscribe_to_alternative(cx);
+ cx.notify();
+ }
+
+ pub fn start(&mut self, user_prompt: String, cx: &mut ModelContext<Self>) -> Result<()> {
+ let alternative_models = LanguageModelRegistry::read_global(cx)
+ .inline_alternative_models()
+ .to_vec();
+
+ self.active_alternative()
+ .update(cx, |alternative, cx| alternative.undo(cx));
+ self.activate(0, cx);
+ self.alternatives.truncate(1);
+
+ for _ in 0..alternative_models.len() {
+ self.alternatives.push(cx.new_model(|cx| {
+ CodegenAlternative::new(
+ self.buffer.clone(),
+ self.range.clone(),
+ false,
+ Some(self.context_store.clone()),
+ Some(self.telemetry.clone()),
+ self.builder.clone(),
+ cx,
+ )
+ }));
+ }
+
+ let primary_model = LanguageModelRegistry::read_global(cx)
+ .active_model()
+ .context("no active model")?;
+
+ for (model, alternative) in iter::once(primary_model)
+ .chain(alternative_models)
+ .zip(&self.alternatives)
+ {
+ alternative.update(cx, |alternative, cx| {
+ alternative.start(user_prompt.clone(), model.clone(), cx)
+ })?;
+ }
+
+ Ok(())
+ }
+
+ pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
+ for codegen in &self.alternatives {
+ codegen.update(cx, |codegen, cx| codegen.stop(cx));
+ }
+ }
+
+ pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
+ self.active_alternative()
+ .update(cx, |codegen, cx| codegen.undo(cx));
+
+ self.buffer.update(cx, |buffer, cx| {
+ if let Some(transaction_id) = self.initial_transaction_id.take() {
+ buffer.undo_transaction(transaction_id, cx);
+ buffer.refresh_preview(cx);
+ }
+ });
+ }
+
+ pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
+ self.active_alternative().read(cx).buffer.clone()
+ }
+
+ pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
+ self.active_alternative().read(cx).old_buffer.clone()
+ }
+
+ pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
+ self.active_alternative().read(cx).snapshot.clone()
+ }
+
+ pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
+ self.active_alternative().read(cx).edit_position
+ }
+
+ pub fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
+ &self.active_alternative().read(cx).diff
+ }
+
+ pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
+ self.active_alternative().read(cx).last_equal_ranges()
+ }
+}
+
+impl EventEmitter<CodegenEvent> for BufferCodegen {}
+
+pub struct CodegenAlternative {
+ buffer: Model<MultiBuffer>,
+ old_buffer: Model<Buffer>,
+ snapshot: MultiBufferSnapshot,
+ edit_position: Option<Anchor>,
+ range: Range<Anchor>,
+ last_equal_ranges: Vec<Range<Anchor>>,
+ transformation_transaction_id: Option<TransactionId>,
+ status: CodegenStatus,
+ generation: Task<()>,
+ diff: Diff,
+ context_store: Option<Model<ContextStore>>,
+ telemetry: Option<Arc<Telemetry>>,
+ _subscription: gpui::Subscription,
+ builder: Arc<PromptBuilder>,
+ active: bool,
+ edits: Vec<(Range<Anchor>, String)>,
+ line_operations: Vec<LineOperation>,
+ request: Option<LanguageModelRequest>,
+ elapsed_time: Option<f64>,
+ completion: Option<String>,
+ pub message_id: Option<String>,
+}
+
+impl EventEmitter<CodegenEvent> for CodegenAlternative {}
+
+impl CodegenAlternative {
+ pub fn new(
+ buffer: Model<MultiBuffer>,
+ range: Range<Anchor>,
+ active: bool,
+ context_store: Option<Model<ContextStore>>,
+ telemetry: Option<Arc<Telemetry>>,
+ builder: Arc<PromptBuilder>,
+ cx: &mut ModelContext<Self>,
+ ) -> Self {
+ let snapshot = buffer.read(cx).snapshot(cx);
+
+ let (old_buffer, _, _) = buffer
+ .read(cx)
+ .range_to_buffer_ranges(range.clone(), cx)
+ .pop()
+ .unwrap();
+ let old_buffer = cx.new_model(|cx| {
+ let old_buffer = old_buffer.read(cx);
+ let text = old_buffer.as_rope().clone();
+ let line_ending = old_buffer.line_ending();
+ let language = old_buffer.language().cloned();
+ let language_registry = old_buffer.language_registry();
+
+ let mut buffer = Buffer::local_normalized(text, line_ending, cx);
+ buffer.set_language(language, cx);
+ if let Some(language_registry) = language_registry {
+ buffer.set_language_registry(language_registry)
+ }
+ buffer
+ });
+
+ Self {
+ buffer: buffer.clone(),
+ old_buffer,
+ edit_position: None,
+ message_id: None,
+ snapshot,
+ last_equal_ranges: Default::default(),
+ transformation_transaction_id: None,
+ status: CodegenStatus::Idle,
+ generation: Task::ready(()),
+ diff: Diff::default(),
+ context_store,
+ telemetry,
+ _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
+ builder,
+ active,
+ edits: Vec::new(),
+ line_operations: Vec::new(),
+ range,
+ request: None,
+ elapsed_time: None,
+ completion: None,
+ }
+ }
+
+ pub fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
+ if active != self.active {
+ self.active = active;
+
+ if self.active {
+ let edits = self.edits.clone();
+ self.apply_edits(edits, cx);
+ if matches!(self.status, CodegenStatus::Pending) {
+ let line_operations = self.line_operations.clone();
+ self.reapply_line_based_diff(line_operations, cx);
+ } else {
+ self.reapply_batch_diff(cx).detach();
+ }
+ } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
+ self.buffer.update(cx, |buffer, cx| {
+ buffer.undo_transaction(transaction_id, cx);
+ buffer.forget_transaction(transaction_id, cx);
+ });
+ }
+ }
+ }
+
+ fn handle_buffer_event(
+ &mut self,
+ _buffer: Model<MultiBuffer>,
+ event: &multi_buffer::Event,
+ cx: &mut ModelContext<Self>,
+ ) {
+ if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
+ if self.transformation_transaction_id == Some(*transaction_id) {
+ self.transformation_transaction_id = None;
+ self.generation = Task::ready(());
+ cx.emit(CodegenEvent::Undone);
+ }
+ }
+ }
+
+ pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
+ &self.last_equal_ranges
+ }
+
+ pub fn start(
+ &mut self,
+ user_prompt: String,
+ model: Arc<dyn LanguageModel>,
+ cx: &mut ModelContext<Self>,
+ ) -> Result<()> {
+ if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
+ self.buffer.update(cx, |buffer, cx| {
+ buffer.undo_transaction(transformation_transaction_id, cx);
+ });
+ }
+
+ self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
+
+ let api_key = model.api_key(cx);
+ let telemetry_id = model.telemetry_id();
+ let provider_id = model.provider_id();
+ let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
+ if user_prompt.trim().to_lowercase() == "delete" {
+ async { Ok(LanguageModelTextStream::default()) }.boxed_local()
+ } else {
+ let request = self.build_request(user_prompt, cx)?;
+ self.request = Some(request.clone());
+
+ cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await })
+ .boxed_local()
+ };
+ self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
+ Ok(())
+ }
+
+ fn build_request(
+ &self,
+ user_prompt: String,
+ cx: &mut AppContext,
+ ) -> Result<LanguageModelRequest> {
+ let buffer = self.buffer.read(cx).snapshot(cx);
+ let language = buffer.language_at(self.range.start);
+ let language_name = if let Some(language) = language.as_ref() {
+ if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
+ None
+ } else {
+ Some(language.name())
+ }
+ } else {
+ None
+ };
+
+ let language_name = language_name.as_ref();
+ let start = buffer.point_to_buffer_offset(self.range.start);
+ let end = buffer.point_to_buffer_offset(self.range.end);
+ let (buffer, range) = if let Some((start, end)) = start.zip(end) {
+ let (start_buffer, start_buffer_offset) = start;
+ let (end_buffer, end_buffer_offset) = end;
+ if start_buffer.remote_id() == end_buffer.remote_id() {
+ (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
+ } else {
+ return Err(anyhow::anyhow!("invalid transformation range"));
+ }
+ } else {
+ return Err(anyhow::anyhow!("invalid transformation range"));
+ };
+
+ let prompt = self
+ .builder
+ .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
+ .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
+
+ let mut request_message = LanguageModelRequestMessage {
+ role: Role::User,
+ content: Vec::new(),
+ cache: false,
+ };
+
+ if let Some(context_store) = &self.context_store {
+ let context = context_store.update(cx, |this, _cx| this.context().clone());
+ attach_context_to_message(&mut request_message, context);
+ }
+
+ request_message.content.push(prompt.into());
+
+ Ok(LanguageModelRequest {
+ tools: Vec::new(),
+ stop: Vec::new(),
+ temperature: None,
+ messages: vec![request_message],
+ })
+ }
+
+ pub fn handle_stream(
+ &mut self,
+ model_telemetry_id: String,
+ model_provider_id: String,
+ model_api_key: Option<String>,
+ stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
+ cx: &mut ModelContext<Self>,
+ ) {
+ let start_time = Instant::now();
+ let snapshot = self.snapshot.clone();
+ let selected_text = snapshot
+ .text_for_range(self.range.start..self.range.end)
+ .collect::<Rope>();
+
+ let selection_start = self.range.start.to_point(&snapshot);
+
+ // Start with the indentation of the first line in the selection
+ let mut suggested_line_indent = snapshot
+ .suggested_indents(selection_start.row..=selection_start.row, cx)
+ .into_values()
+ .next()
+ .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
+
+ // If the first line in the selection does not have indentation, check the following lines
+ if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
+ for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
+ let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
+ // Prefer tabs if a line in the selection uses tabs as indentation
+ if line_indent.kind == IndentKind::Tab {
+ suggested_line_indent.kind = IndentKind::Tab;
+ break;
+ }
+ }
+ }
+
+ let http_client = cx.http_client().clone();
+ let telemetry = self.telemetry.clone();
+ let language_name = {
+ let multibuffer = self.buffer.read(cx);
+ let ranges = multibuffer.range_to_buffer_ranges(self.range.clone(), cx);
+ ranges
+ .first()
+ .and_then(|(buffer, _, _)| buffer.read(cx).language())
+ .map(|language| language.name())
+ };
+
+ self.diff = Diff::default();
+ self.status = CodegenStatus::Pending;
+ let mut edit_start = self.range.start.to_offset(&snapshot);
+ let completion = Arc::new(Mutex::new(String::new()));
+ let completion_clone = completion.clone();
+
+ self.generation = cx.spawn(|codegen, mut cx| {
+ async move {
+ let stream = stream.await;
+ let message_id = stream
+ .as_ref()
+ .ok()
+ .and_then(|stream| stream.message_id.clone());
+ let generate = async {
+ let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
+ let executor = cx.background_executor().clone();
+ let message_id = message_id.clone();
+ let line_based_stream_diff: Task<anyhow::Result<()>> =
+ cx.background_executor().spawn(async move {
+ let mut response_latency = None;
+ let request_start = Instant::now();
+ let diff = async {
+ let chunks = StripInvalidSpans::new(stream?.stream);
+ futures::pin_mut!(chunks);
+ let mut diff = StreamingDiff::new(selected_text.to_string());
+ let mut line_diff = LineDiff::default();
+
+ let mut new_text = String::new();
+ let mut base_indent = None;
+ let mut line_indent = None;
+ let mut first_line = true;
+
+ while let Some(chunk) = chunks.next().await {
+ if response_latency.is_none() {
+ response_latency = Some(request_start.elapsed());
+ }
+ let chunk = chunk?;
+ completion_clone.lock().push_str(&chunk);
+
+ let mut lines = chunk.split('\n').peekable();
+ while let Some(line) = lines.next() {
+ new_text.push_str(line);
+ if line_indent.is_none() {
+ if let Some(non_whitespace_ch_ix) =
+ new_text.find(|ch: char| !ch.is_whitespace())
+ {
+ line_indent = Some(non_whitespace_ch_ix);
+ base_indent = base_indent.or(line_indent);
+
+ let line_indent = line_indent.unwrap();
+ let base_indent = base_indent.unwrap();
+ let indent_delta =
+ line_indent as i32 - base_indent as i32;
+ let mut corrected_indent_len = cmp::max(
+ 0,
+ suggested_line_indent.len as i32 + indent_delta,
+ )
+ as usize;
+ if first_line {
+ corrected_indent_len = corrected_indent_len
+ .saturating_sub(
+ selection_start.column as usize,
+ );
+ }
+
+ let indent_char = suggested_line_indent.char();
+ let mut indent_buffer = [0; 4];
+ let indent_str =
+ indent_char.encode_utf8(&mut indent_buffer);
+ new_text.replace_range(
+ ..line_indent,
+ &indent_str.repeat(corrected_indent_len),
+ );
+ }
+ }
+
+ if line_indent.is_some() {
+ let char_ops = diff.push_new(&new_text);
+ line_diff
+ .push_char_operations(&char_ops, &selected_text);
+ diff_tx
+ .send((char_ops, line_diff.line_operations()))
+ .await?;
+ new_text.clear();
+ }
+
+ if lines.peek().is_some() {
+ let char_ops = diff.push_new("\n");
+ line_diff
+ .push_char_operations(&char_ops, &selected_text);
+ diff_tx
+ .send((char_ops, line_diff.line_operations()))
+ .await?;
+ if line_indent.is_none() {
+ // Don't write out the leading indentation in empty lines on the next line
+ // This is the case where the above if statement didn't clear the buffer
+ new_text.clear();
+ }
+ line_indent = None;
+ first_line = false;
+ }
+ }
+ }
+
+ let mut char_ops = diff.push_new(&new_text);
+ char_ops.extend(diff.finish());
+ line_diff.push_char_operations(&char_ops, &selected_text);
+ line_diff.finish(&selected_text);
+ diff_tx
+ .send((char_ops, line_diff.line_operations()))
+ .await?;
+
+ anyhow::Ok(())
+ };
+
+ let result = diff.await;
+
+ let error_message =
+ result.as_ref().err().map(|error| error.to_string());
+ report_assistant_event(
+ AssistantEvent {
+ conversation_id: None,
+ message_id,
+ kind: AssistantKind::Inline,
+ phase: AssistantPhase::Response,
+ model: model_telemetry_id,
+ model_provider: model_provider_id.to_string(),
+ response_latency,
+ error_message,
+ language_name: language_name.map(|name| name.to_proto()),
+ },
+ telemetry,
+ http_client,
+ model_api_key,
+ &executor,
+ );
+
+ result?;
+ Ok(())
+ });
+
+ while let Some((char_ops, line_ops)) = diff_rx.next().await {
+ codegen.update(&mut cx, |codegen, cx| {
+ codegen.last_equal_ranges.clear();
+
+ let edits = char_ops
+ .into_iter()
+ .filter_map(|operation| match operation {
+ CharOperation::Insert { text } => {
+ let edit_start = snapshot.anchor_after(edit_start);
+ Some((edit_start..edit_start, text))
+ }
+ CharOperation::Delete { bytes } => {
+ let edit_end = edit_start + bytes;
+ let edit_range = snapshot.anchor_after(edit_start)
+ ..snapshot.anchor_before(edit_end);
+ edit_start = edit_end;
+ Some((edit_range, String::new()))
+ }
+ CharOperation::Keep { bytes } => {
+ let edit_end = edit_start + bytes;
+ let edit_range = snapshot.anchor_after(edit_start)
+ ..snapshot.anchor_before(edit_end);
+ edit_start = edit_end;
+ codegen.last_equal_ranges.push(edit_range);
+ None
+ }
+ })
+ .collect::<Vec<_>>();
+
+ if codegen.active {
+ codegen.apply_edits(edits.iter().cloned(), cx);
+ codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
+ }
+ codegen.edits.extend(edits);
+ codegen.line_operations = line_ops;
+ codegen.edit_position = Some(snapshot.anchor_after(edit_start));
+
+ cx.notify();
+ })?;
+ }
+
+ // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
+ // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
+ // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
+ let batch_diff_task =
+ codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
+ let (line_based_stream_diff, ()) =
+ join!(line_based_stream_diff, batch_diff_task);
+ line_based_stream_diff?;
+
+ anyhow::Ok(())
+ };
+
+ let result = generate.await;
+ let elapsed_time = start_time.elapsed().as_secs_f64();
+
+ codegen
+ .update(&mut cx, |this, cx| {
+ this.message_id = message_id;
+ this.last_equal_ranges.clear();
+ if let Err(error) = result {
+ this.status = CodegenStatus::Error(error);
+ } else {
+ this.status = CodegenStatus::Done;
+ }
+ this.elapsed_time = Some(elapsed_time);
+ this.completion = Some(completion.lock().clone());
+ cx.emit(CodegenEvent::Finished);
+ cx.notify();
+ })
+ .ok();
+ }
+ });
+ cx.notify();
+ }
+
+ pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
+ self.last_equal_ranges.clear();
+ if self.diff.is_empty() {
+ self.status = CodegenStatus::Idle;
+ } else {
+ self.status = CodegenStatus::Done;
+ }
+ self.generation = Task::ready(());
+ cx.emit(CodegenEvent::Finished);
+ cx.notify();
+ }
+
+ pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
+ self.buffer.update(cx, |buffer, cx| {
+ if let Some(transaction_id) = self.transformation_transaction_id.take() {
+ buffer.undo_transaction(transaction_id, cx);
+ buffer.refresh_preview(cx);
+ }
+ });
+ }
+
+ fn apply_edits(
+ &mut self,
+ edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
+ cx: &mut ModelContext<CodegenAlternative>,
+ ) {
+ let transaction = self.buffer.update(cx, |buffer, cx| {
+ // Avoid grouping assistant edits with user edits.
+ buffer.finalize_last_transaction(cx);
+ buffer.start_transaction(cx);
+ buffer.edit(edits, None, cx);
+ buffer.end_transaction(cx)
+ });
+
+ if let Some(transaction) = transaction {
+ if let Some(first_transaction) = self.transformation_transaction_id {
+ // Group all assistant edits into the first transaction.
+ self.buffer.update(cx, |buffer, cx| {
+ buffer.merge_transactions(transaction, first_transaction, cx)
+ });
+ } else {
+ self.transformation_transaction_id = Some(transaction);
+ self.buffer
+ .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
+ }
+ }
+ }
+
+ fn reapply_line_based_diff(
+ &mut self,
+ line_operations: impl IntoIterator<Item = LineOperation>,
+ cx: &mut ModelContext<Self>,
+ ) {
+ let old_snapshot = self.snapshot.clone();
+ let old_range = self.range.to_point(&old_snapshot);
+ let new_snapshot = self.buffer.read(cx).snapshot(cx);
+ let new_range = self.range.to_point(&new_snapshot);
+
+ let mut old_row = old_range.start.row;
+ let mut new_row = new_range.start.row;
+
+ self.diff.deleted_row_ranges.clear();
+ self.diff.inserted_row_ranges.clear();
+ for operation in line_operations {
+ match operation {
+ LineOperation::Keep { lines } => {
+ old_row += lines;
+ new_row += lines;
+ }
+ LineOperation::Delete { lines } => {
+ let old_end_row = old_row + lines - 1;
+ let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
+
+ if let Some((_, last_deleted_row_range)) =
+ self.diff.deleted_row_ranges.last_mut()
+ {
+ if *last_deleted_row_range.end() + 1 == old_row {
+ *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
+ } else {
+ self.diff
+ .deleted_row_ranges
+ .push((new_row, old_row..=old_end_row));
+ }
+ } else {
+ self.diff
+ .deleted_row_ranges
+ .push((new_row, old_row..=old_end_row));
+ }
+
+ old_row += lines;
+ }
+ LineOperation::Insert { lines } => {
+ let new_end_row = new_row + lines - 1;
+ let start = new_snapshot.anchor_before(Point::new(new_row, 0));
+ let end = new_snapshot.anchor_before(Point::new(
+ new_end_row,
+ new_snapshot.line_len(MultiBufferRow(new_end_row)),
+ ));
+ self.diff.inserted_row_ranges.push(start..end);
+ new_row += lines;
+ }
+ }
+
+ cx.notify();
+ }
+ }
+
+ fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
+ let old_snapshot = self.snapshot.clone();
+ let old_range = self.range.to_point(&old_snapshot);
+ let new_snapshot = self.buffer.read(cx).snapshot(cx);
+ let new_range = self.range.to_point(&new_snapshot);
+
+ cx.spawn(|codegen, mut cx| async move {
+ let (deleted_row_ranges, inserted_row_ranges) = cx
+ .background_executor()
+ .spawn(async move {
+ let old_text = old_snapshot
+ .text_for_range(
+ Point::new(old_range.start.row, 0)
+ ..Point::new(
+ old_range.end.row,
+ old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
+ ),
+ )
+ .collect::<String>();
+ let new_text = new_snapshot
+ .text_for_range(
+ Point::new(new_range.start.row, 0)
+ ..Point::new(
+ new_range.end.row,
+ new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
+ ),
+ )
+ .collect::<String>();
+
+ let mut old_row = old_range.start.row;
+ let mut new_row = new_range.start.row;
+ let batch_diff =
+ similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
+
+ let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
+ let mut inserted_row_ranges = Vec::new();
+ for change in batch_diff.iter_all_changes() {
+ let line_count = change.value().lines().count() as u32;
+ match change.tag() {
+ similar::ChangeTag::Equal => {
+ old_row += line_count;
+ new_row += line_count;
+ }
+ similar::ChangeTag::Delete => {
+ let old_end_row = old_row + line_count - 1;
+ let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
+
+ if let Some((_, last_deleted_row_range)) =
+ deleted_row_ranges.last_mut()
+ {
+ if *last_deleted_row_range.end() + 1 == old_row {
+ *last_deleted_row_range =
+ *last_deleted_row_range.start()..=old_end_row;
+ } else {
+ deleted_row_ranges.push((new_row, old_row..=old_end_row));
+ }
+ } else {
+ deleted_row_ranges.push((new_row, old_row..=old_end_row));
+ }
+
+ old_row += line_count;
+ }
+ similar::ChangeTag::Insert => {
+ let new_end_row = new_row + line_count - 1;
+ let start = new_snapshot.anchor_before(Point::new(new_row, 0));
+ let end = new_snapshot.anchor_before(Point::new(
+ new_end_row,
+ new_snapshot.line_len(MultiBufferRow(new_end_row)),
+ ));
+ inserted_row_ranges.push(start..end);
+ new_row += line_count;
+ }
+ }
+ }
+
+ (deleted_row_ranges, inserted_row_ranges)
+ })
+ .await;
+
+ codegen
+ .update(&mut cx, |codegen, cx| {
+ codegen.diff.deleted_row_ranges = deleted_row_ranges;
+ codegen.diff.inserted_row_ranges = inserted_row_ranges;
+ cx.notify();
+ })
+ .ok();
+ })
+ }
+}
+
+#[derive(Copy, Clone, Debug)]
+pub enum CodegenEvent {
+ Finished,
+ Undone,
+}
+
+struct StripInvalidSpans<T> {
+ stream: T,
+ stream_done: bool,
+ buffer: String,
+ first_line: bool,
+ line_end: bool,
+ starts_with_code_block: bool,
+}
+
+impl<T> StripInvalidSpans<T>
+where
+ T: Stream<Item = Result<String>>,
+{
+ fn new(stream: T) -> Self {
+ Self {
+ stream,
+ stream_done: false,
+ buffer: String::new(),
+ first_line: true,
+ line_end: false,
+ starts_with_code_block: false,
+ }
+ }
+}
+
+impl<T> Stream for StripInvalidSpans<T>
+where
+ T: Stream<Item = Result<String>>,
+{
+ type Item = Result<String>;
+
+ fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
+ const CODE_BLOCK_DELIMITER: &str = "```";
+ const CURSOR_SPAN: &str = "<|CURSOR|>";
+
+ let this = unsafe { self.get_unchecked_mut() };
+ loop {
+ if !this.stream_done {
+ let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
+ match stream.as_mut().poll_next(cx) {
+ Poll::Ready(Some(Ok(chunk))) => {
+ this.buffer.push_str(&chunk);
+ }
+ Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
+ Poll::Ready(None) => {
+ this.stream_done = true;
+ }
+ Poll::Pending => return Poll::Pending,
+ }
+ }
+
+ let mut chunk = String::new();
+ let mut consumed = 0;
+ if !this.buffer.is_empty() {
+ let mut lines = this.buffer.split('\n').enumerate().peekable();
+ while let Some((line_ix, line)) = lines.next() {
+ if line_ix > 0 {
+ this.first_line = false;
+ }
+
+ if this.first_line {
+ let trimmed_line = line.trim();
+ if lines.peek().is_some() {
+ if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
+ consumed += line.len() + 1;
+ this.starts_with_code_block = true;
+ continue;
+ }
+ } else if trimmed_line.is_empty()
+ || prefixes(CODE_BLOCK_DELIMITER)
+ .any(|prefix| trimmed_line.starts_with(prefix))
+ {
+ break;
+ }
+ }
+
+ let line_without_cursor = line.replace(CURSOR_SPAN, "");
+ if lines.peek().is_some() {
+ if this.line_end {
+ chunk.push('\n');
+ }
+
+ chunk.push_str(&line_without_cursor);
+ this.line_end = true;
+ consumed += line.len() + 1;
+ } else if this.stream_done {
+ if !this.starts_with_code_block
+ || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
+ {
+ if this.line_end {
+ chunk.push('\n');
+ }
+
+ chunk.push_str(&line);
+ }
+
+ consumed += line.len();
+ } else {
+ let trimmed_line = line.trim();
+ if trimmed_line.is_empty()
+ || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
+ || prefixes(CODE_BLOCK_DELIMITER)
+ .any(|prefix| trimmed_line.ends_with(prefix))
+ {
+ break;
+ } else {
+ if this.line_end {
+ chunk.push('\n');
+ this.line_end = false;
+ }
+
+ chunk.push_str(&line_without_cursor);
+ consumed += line.len();
+ }
+ }
+ }
+ }
+
+ this.buffer = this.buffer.split_off(consumed);
+ if !chunk.is_empty() {
+ return Poll::Ready(Some(Ok(chunk)));
+ } else if this.stream_done {
+ return Poll::Ready(None);
+ }
+ }
+ }
+}
+
+fn prefixes(text: &str) -> impl Iterator<Item = &str> {
+ (0..text.len() - 1).map(|ix| &text[..ix + 1])
+}
+
+#[derive(Default)]
+pub struct Diff {
+ pub deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
+ pub inserted_row_ranges: Vec<Range<Anchor>>,
+}
+
+impl Diff {
+ fn is_empty(&self) -> bool {
+ self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use futures::{
+ stream::{self},
+ Stream,
+ };
+ use gpui::{Context, TestAppContext};
+ use indoc::indoc;
+ use language::{
+ language_settings, tree_sitter_rust, Buffer, Language, LanguageConfig, LanguageMatcher,
+ Point,
+ };
+ use language_model::LanguageModelRegistry;
+ use rand::prelude::*;
+ use serde::Serialize;
+ use settings::SettingsStore;
+ use std::{future, sync::Arc};
+
+ #[derive(Serialize)]
+ pub struct DummyCompletionRequest {
+ pub name: String,
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_transform_autoindent(cx: &mut TestAppContext, mut rng: StdRng) {
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_model::LanguageModelRegistry::test);
+ cx.update(language_settings::init);
+
+ let text = indoc! {"
+ fn main() {
+ let x = 0;
+ for _ in 0..10 {
+ x += 1;
+ }
+ }
+ "};
+ let buffer =
+ cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(4, 5))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ CodegenAlternative::new(
+ buffer.clone(),
+ range.clone(),
+ true,
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let chunks_tx = simulate_response_stream(codegen.clone(), cx);
+
+ let mut new_text = concat!(
+ " let mut x = 0;\n",
+ " while x < 10 {\n",
+ " x += 1;\n",
+ " }",
+ );
+ while !new_text.is_empty() {
+ let max_len = cmp::min(new_text.len(), 10);
+ let len = rng.gen_range(1..=max_len);
+ let (chunk, suffix) = new_text.split_at(len);
+ chunks_tx.unbounded_send(chunk.to_string()).unwrap();
+ new_text = suffix;
+ cx.background_executor.run_until_parked();
+ }
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ fn main() {
+ let mut x = 0;
+ while x < 10 {
+ x += 1;
+ }
+ }
+ "}
+ );
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_autoindent_when_generating_past_indentation(
+ cx: &mut TestAppContext,
+ mut rng: StdRng,
+ ) {
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_settings::init);
+
+ let text = indoc! {"
+ fn main() {
+ le
+ }
+ "};
+ let buffer =
+ cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(1, 6))..snapshot.anchor_after(Point::new(1, 6))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ CodegenAlternative::new(
+ buffer.clone(),
+ range.clone(),
+ true,
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let chunks_tx = simulate_response_stream(codegen.clone(), cx);
+
+ cx.background_executor.run_until_parked();
+
+ let mut new_text = concat!(
+ "t mut x = 0;\n",
+ "while x < 10 {\n",
+ " x += 1;\n",
+ "}", //
+ );
+ while !new_text.is_empty() {
+ let max_len = cmp::min(new_text.len(), 10);
+ let len = rng.gen_range(1..=max_len);
+ let (chunk, suffix) = new_text.split_at(len);
+ chunks_tx.unbounded_send(chunk.to_string()).unwrap();
+ new_text = suffix;
+ cx.background_executor.run_until_parked();
+ }
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ fn main() {
+ let mut x = 0;
+ while x < 10 {
+ x += 1;
+ }
+ }
+ "}
+ );
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_autoindent_when_generating_before_indentation(
+ cx: &mut TestAppContext,
+ mut rng: StdRng,
+ ) {
+ cx.update(LanguageModelRegistry::test);
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_settings::init);
+
+ let text = concat!(
+ "fn main() {\n",
+ " \n",
+ "}\n" //
+ );
+ let buffer =
+ cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(1, 2))..snapshot.anchor_after(Point::new(1, 2))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ CodegenAlternative::new(
+ buffer.clone(),
+ range.clone(),
+ true,
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let chunks_tx = simulate_response_stream(codegen.clone(), cx);
+
+ cx.background_executor.run_until_parked();
+
+ let mut new_text = concat!(
+ "let mut x = 0;\n",
+ "while x < 10 {\n",
+ " x += 1;\n",
+ "}", //
+ );
+ while !new_text.is_empty() {
+ let max_len = cmp::min(new_text.len(), 10);
+ let len = rng.gen_range(1..=max_len);
+ let (chunk, suffix) = new_text.split_at(len);
+ chunks_tx.unbounded_send(chunk.to_string()).unwrap();
+ new_text = suffix;
+ cx.background_executor.run_until_parked();
+ }
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ fn main() {
+ let mut x = 0;
+ while x < 10 {
+ x += 1;
+ }
+ }
+ "}
+ );
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_autoindent_respects_tabs_in_selection(cx: &mut TestAppContext) {
+ cx.update(LanguageModelRegistry::test);
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_settings::init);
+
+ let text = indoc! {"
+ func main() {
+ \tx := 0
+ \tfor i := 0; i < 10; i++ {
+ \t\tx++
+ \t}
+ }
+ "};
+ let buffer = cx.new_model(|cx| Buffer::local(text, cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(0, 0))..snapshot.anchor_after(Point::new(4, 2))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ CodegenAlternative::new(
+ buffer.clone(),
+ range.clone(),
+ true,
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let chunks_tx = simulate_response_stream(codegen.clone(), cx);
+ let new_text = concat!(
+ "func main() {\n",
+ "\tx := 0\n",
+ "\tfor x < 10 {\n",
+ "\t\tx++\n",
+ "\t}", //
+ );
+ chunks_tx.unbounded_send(new_text.to_string()).unwrap();
+ drop(chunks_tx);
+ cx.background_executor.run_until_parked();
+
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ func main() {
+ \tx := 0
+ \tfor x < 10 {
+ \t\tx++
+ \t}
+ }
+ "}
+ );
+ }
+
+ #[gpui::test]
+ async fn test_inactive_codegen_alternative(cx: &mut TestAppContext) {
+ cx.update(LanguageModelRegistry::test);
+ cx.set_global(cx.update(SettingsStore::test));
+ cx.update(language_settings::init);
+
+ let text = indoc! {"
+ fn main() {
+ let x = 0;
+ }
+ "};
+ let buffer =
+ cx.new_model(|cx| Buffer::local(text, cx).with_language(Arc::new(rust_lang()), cx));
+ let buffer = cx.new_model(|cx| MultiBuffer::singleton(buffer, cx));
+ let range = buffer.read_with(cx, |buffer, cx| {
+ let snapshot = buffer.snapshot(cx);
+ snapshot.anchor_before(Point::new(1, 0))..snapshot.anchor_after(Point::new(1, 14))
+ });
+ let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+ let codegen = cx.new_model(|cx| {
+ CodegenAlternative::new(
+ buffer.clone(),
+ range.clone(),
+ false,
+ None,
+ None,
+ prompt_builder,
+ cx,
+ )
+ });
+
+ let chunks_tx = simulate_response_stream(codegen.clone(), cx);
+ chunks_tx
+ .unbounded_send("let mut x = 0;\nx += 1;".to_string())
+ .unwrap();
+ drop(chunks_tx);
+ cx.run_until_parked();
+
+ // The codegen is inactive, so the buffer doesn't get modified.
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ text
+ );
+
+ // Activating the codegen applies the changes.
+ codegen.update(cx, |codegen, cx| codegen.set_active(true, cx));
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ indoc! {"
+ fn main() {
+ let mut x = 0;
+ x += 1;
+ }
+ "}
+ );
+
+ // Deactivating the codegen undoes the changes.
+ codegen.update(cx, |codegen, cx| codegen.set_active(false, cx));
+ cx.run_until_parked();
+ assert_eq!(
+ buffer.read_with(cx, |buffer, cx| buffer.snapshot(cx).text()),
+ text
+ );
+ }
+
+ #[gpui::test]
+ async fn test_strip_invalid_spans_from_codeblock() {
+ assert_chunks("Lorem ipsum dolor", "Lorem ipsum dolor").await;
+ assert_chunks("```\nLorem ipsum dolor", "Lorem ipsum dolor").await;
+ assert_chunks("```\nLorem ipsum dolor\n```", "Lorem ipsum dolor").await;
+ assert_chunks(
+ "```html\n```js\nLorem ipsum dolor\n```\n```",
+ "```js\nLorem ipsum dolor\n```",
+ )
+ .await;
+ assert_chunks("``\nLorem ipsum dolor\n```", "``\nLorem ipsum dolor\n```").await;
+ assert_chunks("Lorem<|CURSOR|> ipsum", "Lorem ipsum").await;
+ assert_chunks("Lorem ipsum", "Lorem ipsum").await;
+ assert_chunks("```\n<|CURSOR|>Lorem ipsum\n```", "Lorem ipsum").await;
+
+ async fn assert_chunks(text: &str, expected_text: &str) {
+ for chunk_size in 1..=text.len() {
+ let actual_text = StripInvalidSpans::new(chunks(text, chunk_size))
+ .map(|chunk| chunk.unwrap())
+ .collect::<String>()
+ .await;
+ assert_eq!(
+ actual_text, expected_text,
+ "failed to strip invalid spans, chunk size: {}",
+ chunk_size
+ );
+ }
+ }
+
+ fn chunks(text: &str, size: usize) -> impl Stream<Item = Result<String>> {
+ stream::iter(
+ text.chars()
+ .collect::<Vec<_>>()
+ .chunks(size)
+ .map(|chunk| Ok(chunk.iter().collect::<String>()))
+ .collect::<Vec<_>>(),
+ )
+ }
+ }
+
+ fn simulate_response_stream(
+ codegen: Model<CodegenAlternative>,
+ cx: &mut TestAppContext,
+ ) -> mpsc::UnboundedSender<String> {
+ let (chunks_tx, chunks_rx) = mpsc::unbounded();
+ codegen.update(cx, |codegen, cx| {
+ codegen.handle_stream(
+ String::new(),
+ String::new(),
+ None,
+ future::ready(Ok(LanguageModelTextStream {
+ message_id: None,
+ stream: chunks_rx.map(Ok).boxed(),
+ })),
+ cx,
+ );
+ });
+ chunks_tx
+ }
+
+ fn rust_lang() -> Language {
+ Language::new(
+ LanguageConfig {
+ name: "Rust".into(),
+ matcher: LanguageMatcher {
+ path_suffixes: vec!["rs".to_string()],
+ ..Default::default()
+ },
+ ..Default::default()
+ },
+ Some(tree_sitter_rust::LANGUAGE.into()),
+ )
+ .with_indents_query(
+ r#"
+ (call_expression) @indent
+ (field_expression) @indent
+ (_ "(" ")" @end) @indent
+ (_ "{" "}" @end) @indent
+ "#,
+ )
+ .unwrap()
+ }
+}
@@ -1,73 +1,44 @@
-use crate::context::attach_context_to_message;
-use crate::context_picker::ContextPicker;
+use crate::buffer_codegen::{BufferCodegen, CodegenAlternative, CodegenEvent};
use crate::context_store::ContextStore;
-use crate::context_strip::ContextStrip;
-use crate::inline_prompt_editor::{
- render_cancel_button, CodegenStatus, PromptEditorEvent, PromptMode,
-};
+use crate::inline_prompt_editor::{CodegenStatus, InlineAssistId, PromptEditor, PromptEditorEvent};
use crate::thread_store::ThreadStore;
+use crate::AssistantPanel;
use crate::{
- assistant_settings::AssistantSettings,
- prompts::PromptBuilder,
- streaming_diff::{CharOperation, LineDiff, LineOperation, StreamingDiff},
+ assistant_settings::AssistantSettings, prompts::PromptBuilder,
terminal_inline_assistant::TerminalInlineAssistant,
- CycleNextInlineAssist, CyclePreviousInlineAssist,
};
-use crate::{AssistantPanel, ToggleContextPicker};
use anyhow::{Context as _, Result};
-use client::{telemetry::Telemetry, ErrorExt};
+use client::telemetry::Telemetry;
use collections::{hash_map, HashMap, HashSet, VecDeque};
use editor::{
- actions::{MoveDown, MoveUp, SelectAll},
+ actions::SelectAll,
display_map::{
BlockContext, BlockPlacement, BlockProperties, BlockStyle, CustomBlockId, RenderBlock,
ToDisplayPoint,
},
- Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorElement, EditorEvent, EditorMode,
- EditorStyle, ExcerptId, ExcerptRange, GutterDimensions, MultiBuffer, MultiBufferSnapshot,
- ToOffset as _, ToPoint,
+ Anchor, AnchorRangeExt, CodeActionProvider, Editor, EditorEvent, ExcerptId, ExcerptRange,
+ GutterDimensions, MultiBuffer, MultiBufferSnapshot, ToOffset as _, ToPoint,
};
-use feature_flags::{FeatureFlagAppExt as _, ZedPro};
use fs::Fs;
-use futures::{channel::mpsc, future::LocalBoxFuture, join, SinkExt, Stream, StreamExt};
+use util::ResultExt;
+
use gpui::{
- anchored, deferred, point, AnyElement, AppContext, ClickEvent, CursorStyle, EventEmitter,
- FocusHandle, FocusableView, FontWeight, Global, HighlightStyle, Model, ModelContext,
- Subscription, Task, TextStyle, UpdateGlobal, View, ViewContext, WeakModel, WeakView,
- WindowContext,
+ point, AppContext, FocusableView, Global, HighlightStyle, Model, Subscription, Task,
+ UpdateGlobal, View, ViewContext, WeakModel, WeakView, WindowContext,
};
-use language::{Buffer, IndentKind, Point, Selection, TransactionId};
-use language_model::{
- LanguageModel, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
- LanguageModelTextStream, Role,
-};
-use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
+use language::{Buffer, Point, Selection, TransactionId};
+use language_model::LanguageModelRegistry;
use language_models::report_assistant_event;
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use project::{CodeAction, ProjectTransaction};
-use rope::Rope;
-use settings::{update_settings_file, Settings, SettingsStore};
-use smol::future::FutureExt;
-use std::{
- cmp,
- future::Future,
- iter, mem,
- ops::{Range, RangeInclusive},
- pin::Pin,
- rc::Rc,
- sync::Arc,
- task::{self, Poll},
- time::Instant,
-};
+use settings::{Settings, SettingsStore};
+use std::{cmp, mem, ops::Range, rc::Rc, sync::Arc};
use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
use terminal_view::{terminal_panel::TerminalPanel, TerminalView};
use text::{OffsetRangeExt, ToPoint as _};
-use theme::ThemeSettings;
-use ui::{
- prelude::*, CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, PopoverMenuHandle, Tooltip,
-};
-use util::{RangeExt, ResultExt};
+use ui::prelude::*;
+use util::RangeExt;
use workspace::{dock::Panel, ShowConfiguration};
use workspace::{notifications::NotificationId, ItemHandle, Toast, Workspace};
@@ -366,7 +337,7 @@ impl InlineAssistant {
let assist_id = self.next_assist_id.post_inc();
let context_store = cx.new_model(|_cx| ContextStore::new());
let codegen = cx.new_model(|cx| {
- Codegen::new(
+ BufferCodegen::new(
editor.read(cx).buffer().clone(),
range.clone(),
None,
@@ -379,7 +350,7 @@ impl InlineAssistant {
let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
let prompt_editor = cx.new_view(|cx| {
- PromptEditor::new(
+ PromptEditor::new_buffer(
assist_id,
gutter_dimensions.clone(),
self.prompt_history.clone(),
@@ -422,6 +393,8 @@ impl InlineAssistant {
.or_insert_with(|| EditorInlineAssists::new(&editor, cx));
let mut assist_group = InlineAssistGroup::new();
for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
+ let codegen = prompt_editor.read(cx).codegen().clone();
+
self.assists.insert(
assist_id,
InlineAssist::new(
@@ -432,7 +405,7 @@ impl InlineAssistant {
prompt_block_id,
end_block_id,
range,
- prompt_editor.read(cx).codegen.clone(),
+ codegen,
workspace.clone(),
cx,
),
@@ -475,7 +448,7 @@ impl InlineAssistant {
let context_store = cx.new_model(|_cx| ContextStore::new());
let codegen = cx.new_model(|cx| {
- Codegen::new(
+ BufferCodegen::new(
editor.read(cx).buffer().clone(),
range.clone(),
initial_transaction_id,
@@ -488,7 +461,7 @@ impl InlineAssistant {
let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
let prompt_editor = cx.new_view(|cx| {
- PromptEditor::new(
+ PromptEditor::new_buffer(
assist_id,
gutter_dimensions.clone(),
self.prompt_history.clone(),
@@ -521,7 +494,7 @@ impl InlineAssistant {
prompt_block_id,
end_block_id,
range,
- prompt_editor.read(cx).codegen.clone(),
+ codegen.clone(),
workspace.clone(),
cx,
),
@@ -541,7 +514,7 @@ impl InlineAssistant {
&self,
editor: &View<Editor>,
range: &Range<Anchor>,
- prompt_editor: &View<PromptEditor>,
+ prompt_editor: &View<PromptEditor<BufferCodegen>>,
cx: &mut WindowContext,
) -> [CustomBlockId; 2] {
let prompt_editor_height = prompt_editor.update(cx, |prompt_editor, cx| {
@@ -643,11 +616,11 @@ impl InlineAssistant {
fn handle_prompt_editor_event(
&mut self,
- prompt_editor: View<PromptEditor>,
+ prompt_editor: View<PromptEditor<BufferCodegen>>,
event: &PromptEditorEvent,
cx: &mut WindowContext,
) {
- let assist_id = prompt_editor.read(cx).id;
+ let assist_id = prompt_editor.read(cx).id();
match event {
PromptEditorEvent::StartRequested => {
self.start_assist(assist_id, cx);
@@ -665,7 +638,7 @@ impl InlineAssistant {
self.dismiss_assist(assist_id, cx);
}
PromptEditorEvent::Resized { .. } => {
- // This only matters for the terminal inline
+ // This only matters for the terminal inline assistant
}
}
}
@@ -1451,25 +1424,17 @@ impl InlineAssistGroup {
}
}
-fn build_assist_editor_renderer(editor: &View<PromptEditor>) -> RenderBlock {
+fn build_assist_editor_renderer(editor: &View<PromptEditor<BufferCodegen>>) -> RenderBlock {
let editor = editor.clone();
+
Arc::new(move |cx: &mut BlockContext| {
- *editor.read(cx).gutter_dimensions.lock() = *cx.gutter_dimensions;
+ let gutter_dimensions = editor.read(cx).gutter_dimensions();
+
+ *gutter_dimensions.lock() = *cx.gutter_dimensions;
editor.clone().into_any_element()
})
}
-#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
-pub struct InlineAssistId(usize);
-
-impl InlineAssistId {
- fn post_inc(&mut self) -> InlineAssistId {
- let id = *self;
- self.0 += 1;
- id
- }
-}
-
#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
struct InlineAssistGroupId(usize);
@@ -1481,689 +1446,12 @@ impl InlineAssistGroupId {
}
}
-struct PromptEditor {
- id: InlineAssistId,
- editor: View<Editor>,
- context_strip: View<ContextStrip>,
- context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
- language_model_selector: View<LanguageModelSelector>,
- edited_since_done: bool,
- gutter_dimensions: Arc<Mutex<GutterDimensions>>,
- prompt_history: VecDeque<String>,
- prompt_history_ix: Option<usize>,
- pending_prompt: String,
- codegen: Model<Codegen>,
- _codegen_subscription: Subscription,
- editor_subscriptions: Vec<Subscription>,
- show_rate_limit_notice: bool,
-}
-
-impl EventEmitter<PromptEditorEvent> for PromptEditor {}
-
-impl Render for PromptEditor {
- fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
- let gutter_dimensions = *self.gutter_dimensions.lock();
- let mut buttons = Vec::new();
- let codegen = self.codegen.read(cx);
- if codegen.alternative_count(cx) > 1 {
- buttons.push(self.render_cycle_controls(cx));
- }
- let prompt_mode = if codegen.is_insertion {
- PromptMode::Generate {
- supports_execute: false,
- }
- } else {
- PromptMode::Transform
- };
-
- buttons.extend(render_cancel_button(
- codegen.status(cx).into(),
- self.edited_since_done,
- prompt_mode,
- cx,
- ));
-
- v_flex()
- .border_y_1()
- .border_color(cx.theme().status().info_border)
- .size_full()
- .py(cx.line_height() / 2.5)
- .child(
- h_flex()
- .key_context("PromptEditor")
- .bg(cx.theme().colors().editor_background)
- .block_mouse_down()
- .cursor(CursorStyle::Arrow)
- .on_action(cx.listener(Self::toggle_context_picker))
- .on_action(cx.listener(Self::confirm))
- .on_action(cx.listener(Self::cancel))
- .on_action(cx.listener(Self::move_up))
- .on_action(cx.listener(Self::move_down))
- .capture_action(cx.listener(Self::cycle_prev))
- .capture_action(cx.listener(Self::cycle_next))
- .child(
- h_flex()
- .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
- .justify_center()
- .gap_2()
- .child(LanguageModelSelectorPopoverMenu::new(
- self.language_model_selector.clone(),
- IconButton::new("context", IconName::SettingsAlt)
- .shape(IconButtonShape::Square)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .tooltip(move |cx| {
- Tooltip::with_meta(
- format!(
- "Using {}",
- LanguageModelRegistry::read_global(cx)
- .active_model()
- .map(|model| model.name().0)
- .unwrap_or_else(|| "No model selected".into()),
- ),
- None,
- "Change Model",
- cx,
- )
- }),
- ))
- .map(|el| {
- let CodegenStatus::Error(error) = self.codegen.read(cx).status(cx)
- else {
- return el;
- };
-
- let error_message = SharedString::from(error.to_string());
- if error.error_code() == proto::ErrorCode::RateLimitExceeded
- && cx.has_flag::<ZedPro>()
- {
- el.child(
- v_flex()
- .child(
- IconButton::new(
- "rate-limit-error",
- IconName::XCircle,
- )
- .toggle_state(self.show_rate_limit_notice)
- .shape(IconButtonShape::Square)
- .icon_size(IconSize::Small)
- .on_click(
- cx.listener(Self::toggle_rate_limit_notice),
- ),
- )
- .children(self.show_rate_limit_notice.then(|| {
- deferred(
- anchored()
- .position_mode(
- gpui::AnchoredPositionMode::Local,
- )
- .position(point(px(0.), px(24.)))
- .anchor(gpui::Corner::TopLeft)
- .child(self.render_rate_limit_notice(cx)),
- )
- })),
- )
- } else {
- el.child(
- div()
- .id("error")
- .tooltip(move |cx| {
- Tooltip::text(error_message.clone(), cx)
- })
- .child(
- Icon::new(IconName::XCircle)
- .size(IconSize::Small)
- .color(Color::Error),
- ),
- )
- }
- }),
- )
- .child(div().flex_1().child(self.render_editor(cx)))
- .child(h_flex().gap_2().pr_6().children(buttons)),
- )
- .child(
- h_flex()
- .child(
- h_flex()
- .w(gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0))
- .justify_center()
- .gap_2(),
- )
- .child(self.context_strip.clone()),
- )
- }
-}
-
-impl FocusableView for PromptEditor {
- fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
- self.editor.focus_handle(cx)
- }
-}
-
-impl PromptEditor {
- const MAX_LINES: u8 = 8;
-
- #[allow(clippy::too_many_arguments)]
- fn new(
- id: InlineAssistId,
- gutter_dimensions: Arc<Mutex<GutterDimensions>>,
- prompt_history: VecDeque<String>,
- prompt_buffer: Model<MultiBuffer>,
- codegen: Model<Codegen>,
- fs: Arc<dyn Fs>,
- context_store: Model<ContextStore>,
- workspace: WeakView<Workspace>,
- thread_store: Option<WeakModel<ThreadStore>>,
- cx: &mut ViewContext<Self>,
- ) -> Self {
- let prompt_editor = cx.new_view(|cx| {
- let mut editor = Editor::new(
- EditorMode::AutoHeight {
- max_lines: Self::MAX_LINES as usize,
- },
- prompt_buffer,
- None,
- false,
- cx,
- );
- editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
- // Since the prompt editors for all inline assistants are linked,
- // always show the cursor (even when it isn't focused) because
- // typing in one will make what you typed appear in all of them.
- editor.set_show_cursor_when_unfocused(true, cx);
- editor.set_placeholder_text(Self::placeholder_text(codegen.read(cx), cx), cx);
- editor
- });
- let context_picker_menu_handle = PopoverMenuHandle::default();
-
- let mut this = Self {
- id,
- editor: prompt_editor.clone(),
- context_strip: cx.new_view(|cx| {
- ContextStrip::new(
- context_store,
- workspace.clone(),
- thread_store.clone(),
- prompt_editor.focus_handle(cx),
- context_picker_menu_handle.clone(),
- cx,
- )
- }),
- context_picker_menu_handle,
- language_model_selector: cx.new_view(|cx| {
- let fs = fs.clone();
- LanguageModelSelector::new(
- move |model, cx| {
- update_settings_file::<AssistantSettings>(
- fs.clone(),
- cx,
- move |settings, _| settings.set_model(model.clone()),
- );
- },
- cx,
- )
- }),
- edited_since_done: false,
- gutter_dimensions,
- prompt_history,
- prompt_history_ix: None,
- pending_prompt: String::new(),
- _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
- editor_subscriptions: Vec::new(),
- codegen,
- show_rate_limit_notice: false,
- };
- this.subscribe_to_editor(cx);
- this
- }
-
- fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
- self.editor_subscriptions.clear();
- self.editor_subscriptions
- .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
- }
-
- fn set_show_cursor_when_unfocused(
- &mut self,
- show_cursor_when_unfocused: bool,
- cx: &mut ViewContext<Self>,
- ) {
- self.editor.update(cx, |editor, cx| {
- editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
- });
- }
-
- fn unlink(&mut self, cx: &mut ViewContext<Self>) {
- let prompt = self.prompt(cx);
- let focus = self.editor.focus_handle(cx).contains_focused(cx);
- self.editor = cx.new_view(|cx| {
- let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
- editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
- editor.set_placeholder_text(Self::placeholder_text(self.codegen.read(cx), cx), cx);
- editor.set_placeholder_text("Add a promptβ¦", cx);
- editor.set_text(prompt, cx);
- if focus {
- editor.focus(cx);
- }
- editor
- });
- self.subscribe_to_editor(cx);
- }
-
- fn placeholder_text(codegen: &Codegen, cx: &WindowContext) -> String {
- let action = if codegen.is_insertion {
- "Generate"
- } else {
- "Transform"
- };
- let assistant_panel_keybinding = ui::text_for_action(&crate::ToggleFocus, cx)
- .map(|keybinding| format!("{keybinding} to chat β "))
- .unwrap_or_default();
-
- format!("{action}β¦ ({assistant_panel_keybinding}ββ for history)")
- }
-
- fn prompt(&self, cx: &AppContext) -> String {
- self.editor.read(cx).text(cx)
- }
-
- fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
- self.show_rate_limit_notice = !self.show_rate_limit_notice;
- if self.show_rate_limit_notice {
- cx.focus_view(&self.editor);
- }
- cx.notify();
- }
-
- fn handle_prompt_editor_events(
- &mut self,
- _: View<Editor>,
- event: &EditorEvent,
- cx: &mut ViewContext<Self>,
- ) {
- match event {
- EditorEvent::Edited { .. } => {
- if let Some(workspace) = cx.window_handle().downcast::<Workspace>() {
- workspace
- .update(cx, |workspace, cx| {
- let is_via_ssh = workspace
- .project()
- .update(cx, |project, _| project.is_via_ssh());
-
- workspace
- .client()
- .telemetry()
- .log_edit_event("inline assist", is_via_ssh);
- })
- .log_err();
- }
- let prompt = self.editor.read(cx).text(cx);
- if self
- .prompt_history_ix
- .map_or(true, |ix| self.prompt_history[ix] != prompt)
- {
- self.prompt_history_ix.take();
- self.pending_prompt = prompt;
- }
-
- self.edited_since_done = true;
- cx.notify();
- }
- EditorEvent::Blurred => {
- if self.show_rate_limit_notice {
- self.show_rate_limit_notice = false;
- cx.notify();
- }
- }
- _ => {}
- }
- }
-
- fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
- match self.codegen.read(cx).status(cx) {
- CodegenStatus::Idle => {
- self.editor
- .update(cx, |editor, _| editor.set_read_only(false));
- }
- CodegenStatus::Pending => {
- self.editor
- .update(cx, |editor, _| editor.set_read_only(true));
- }
- CodegenStatus::Done => {
- self.edited_since_done = false;
- self.editor
- .update(cx, |editor, _| editor.set_read_only(false));
- }
- CodegenStatus::Error(error) => {
- if cx.has_flag::<ZedPro>()
- && error.error_code() == proto::ErrorCode::RateLimitExceeded
- && !dismissed_rate_limit_notice()
- {
- self.show_rate_limit_notice = true;
- cx.notify();
- }
-
- self.edited_since_done = false;
- self.editor
- .update(cx, |editor, _| editor.set_read_only(false));
- }
- }
- }
-
- fn toggle_context_picker(&mut self, _: &ToggleContextPicker, cx: &mut ViewContext<Self>) {
- self.context_picker_menu_handle.toggle(cx);
- }
-
- fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
- match self.codegen.read(cx).status(cx) {
- CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
- cx.emit(PromptEditorEvent::CancelRequested);
- }
- CodegenStatus::Pending => {
- cx.emit(PromptEditorEvent::StopRequested);
- }
- }
- }
-
- fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
- match self.codegen.read(cx).status(cx) {
- CodegenStatus::Idle => {
- cx.emit(PromptEditorEvent::StartRequested);
- }
- CodegenStatus::Pending => {
- cx.emit(PromptEditorEvent::DismissRequested);
- }
- CodegenStatus::Done => {
- if self.edited_since_done {
- cx.emit(PromptEditorEvent::StartRequested);
- } else {
- cx.emit(PromptEditorEvent::ConfirmRequested { execute: false });
- }
- }
- CodegenStatus::Error(_) => {
- cx.emit(PromptEditorEvent::StartRequested);
- }
- }
- }
-
- fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
- if let Some(ix) = self.prompt_history_ix {
- if ix > 0 {
- self.prompt_history_ix = Some(ix - 1);
- let prompt = self.prompt_history[ix - 1].as_str();
- self.editor.update(cx, |editor, cx| {
- editor.set_text(prompt, cx);
- editor.move_to_beginning(&Default::default(), cx);
- });
- }
- } else if !self.prompt_history.is_empty() {
- self.prompt_history_ix = Some(self.prompt_history.len() - 1);
- let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
- self.editor.update(cx, |editor, cx| {
- editor.set_text(prompt, cx);
- editor.move_to_beginning(&Default::default(), cx);
- });
- }
- }
-
- fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
- if let Some(ix) = self.prompt_history_ix {
- if ix < self.prompt_history.len() - 1 {
- self.prompt_history_ix = Some(ix + 1);
- let prompt = self.prompt_history[ix + 1].as_str();
- self.editor.update(cx, |editor, cx| {
- editor.set_text(prompt, cx);
- editor.move_to_end(&Default::default(), cx)
- });
- } else {
- self.prompt_history_ix = None;
- let prompt = self.pending_prompt.as_str();
- self.editor.update(cx, |editor, cx| {
- editor.set_text(prompt, cx);
- editor.move_to_end(&Default::default(), cx)
- });
- }
- }
- }
-
- fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext<Self>) {
- self.codegen
- .update(cx, |codegen, cx| codegen.cycle_prev(cx));
- }
-
- fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext<Self>) {
- self.codegen
- .update(cx, |codegen, cx| codegen.cycle_next(cx));
- }
-
- fn render_cycle_controls(&self, cx: &ViewContext<Self>) -> AnyElement {
- let codegen = self.codegen.read(cx);
- let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
-
- let model_registry = LanguageModelRegistry::read_global(cx);
- let default_model = model_registry.active_model();
- let alternative_models = model_registry.inline_alternative_models();
-
- let get_model_name = |index: usize| -> String {
- let name = |model: &Arc<dyn LanguageModel>| model.name().0.to_string();
-
- match index {
- 0 => default_model.as_ref().map_or_else(String::new, name),
- index if index <= alternative_models.len() => alternative_models
- .get(index - 1)
- .map_or_else(String::new, name),
- _ => String::new(),
- }
- };
-
- let total_models = alternative_models.len() + 1;
-
- if total_models <= 1 {
- return div().into_any_element();
- }
-
- let current_index = codegen.active_alternative;
- let prev_index = (current_index + total_models - 1) % total_models;
- let next_index = (current_index + 1) % total_models;
-
- let prev_model_name = get_model_name(prev_index);
- let next_model_name = get_model_name(next_index);
-
- h_flex()
- .child(
- IconButton::new("previous", IconName::ChevronLeft)
- .icon_color(Color::Muted)
- .disabled(disabled || current_index == 0)
- .shape(IconButtonShape::Square)
- .tooltip({
- let focus_handle = self.editor.focus_handle(cx);
- move |cx| {
- cx.new_view(|cx| {
- let mut tooltip = Tooltip::new("Previous Alternative").key_binding(
- KeyBinding::for_action_in(
- &CyclePreviousInlineAssist,
- &focus_handle,
- cx,
- ),
- );
- if !disabled && current_index != 0 {
- tooltip = tooltip.meta(prev_model_name.clone());
- }
- tooltip
- })
- .into()
- }
- })
- .on_click(cx.listener(|this, _, cx| {
- this.codegen
- .update(cx, |codegen, cx| codegen.cycle_prev(cx))
- })),
- )
- .child(
- Label::new(format!(
- "{}/{}",
- codegen.active_alternative + 1,
- codegen.alternative_count(cx)
- ))
- .size(LabelSize::Small)
- .color(if disabled {
- Color::Disabled
- } else {
- Color::Muted
- }),
- )
- .child(
- IconButton::new("next", IconName::ChevronRight)
- .icon_color(Color::Muted)
- .disabled(disabled || current_index == total_models - 1)
- .shape(IconButtonShape::Square)
- .tooltip({
- let focus_handle = self.editor.focus_handle(cx);
- move |cx| {
- cx.new_view(|cx| {
- let mut tooltip = Tooltip::new("Next Alternative").key_binding(
- KeyBinding::for_action_in(
- &CycleNextInlineAssist,
- &focus_handle,
- cx,
- ),
- );
- if !disabled && current_index != total_models - 1 {
- tooltip = tooltip.meta(next_model_name.clone());
- }
- tooltip
- })
- .into()
- }
- })
- .on_click(cx.listener(|this, _, cx| {
- this.codegen
- .update(cx, |codegen, cx| codegen.cycle_next(cx))
- })),
- )
- .into_any_element()
- }
-
- fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
- Popover::new().child(
- v_flex()
- .occlude()
- .p_2()
- .child(
- Label::new("Out of Tokens")
- .size(LabelSize::Small)
- .weight(FontWeight::BOLD),
- )
- .child(Label::new(
- "Try Zed Pro for higher limits, a wider range of models, and more.",
- ))
- .child(
- h_flex()
- .justify_between()
- .child(CheckboxWithLabel::new(
- "dont-show-again",
- Label::new("Don't show again"),
- if dismissed_rate_limit_notice() {
- ui::ToggleState::Selected
- } else {
- ui::ToggleState::Unselected
- },
- |selection, cx| {
- let is_dismissed = match selection {
- ui::ToggleState::Unselected => false,
- ui::ToggleState::Indeterminate => return,
- ui::ToggleState::Selected => true,
- };
-
- set_rate_limit_notice_dismissed(is_dismissed, cx)
- },
- ))
- .child(
- h_flex()
- .gap_2()
- .child(
- Button::new("dismiss", "Dismiss")
- .style(ButtonStyle::Transparent)
- .on_click(cx.listener(Self::toggle_rate_limit_notice)),
- )
- .child(Button::new("more-info", "More Info").on_click(
- |_event, cx| {
- cx.dispatch_action(Box::new(
- zed_actions::OpenAccountSettings,
- ))
- },
- )),
- ),
- ),
- )
- }
-
- fn render_editor(&mut self, cx: &mut ViewContext<Self>) -> AnyElement {
- let font_size = TextSize::Default.rems(cx);
- let line_height = font_size.to_pixels(cx.rem_size()) * 1.3;
-
- v_flex()
- .key_context("MessageEditor")
- .size_full()
- .gap_2()
- .p_2()
- .bg(cx.theme().colors().editor_background)
- .child({
- let settings = ThemeSettings::get_global(cx);
- let text_style = TextStyle {
- color: cx.theme().colors().editor_foreground,
- font_family: settings.ui_font.family.clone(),
- font_features: settings.ui_font.features.clone(),
- font_size: font_size.into(),
- font_weight: settings.ui_font.weight,
- line_height: line_height.into(),
- ..Default::default()
- };
-
- EditorElement::new(
- &self.editor,
- EditorStyle {
- background: cx.theme().colors().editor_background,
- local_player: cx.theme().players().local(),
- text: text_style,
- ..Default::default()
- },
- )
- })
- .into_any_element()
- }
-}
-
-const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
-
-fn dismissed_rate_limit_notice() -> bool {
- db::kvp::KEY_VALUE_STORE
- .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
- .log_err()
- .map_or(false, |s| s.is_some())
-}
-
-fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
- db::write_and_log(cx, move || async move {
- if is_dismissed {
- db::kvp::KEY_VALUE_STORE
- .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
- .await
- } else {
- db::kvp::KEY_VALUE_STORE
- .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
- .await
- }
- })
-}
-
pub struct InlineAssist {
group_id: InlineAssistGroupId,
range: Range<Anchor>,
editor: WeakView<Editor>,
decorations: Option<InlineAssistDecorations>,
- codegen: Model<Codegen>,
+ codegen: Model<BufferCodegen>,
_subscriptions: Vec<Subscription>,
workspace: WeakView<Workspace>,
}
@@ -2174,11 +1462,11 @@ impl InlineAssist {
assist_id: InlineAssistId,
group_id: InlineAssistGroupId,
editor: &View<Editor>,
- prompt_editor: &View<PromptEditor>,
+ prompt_editor: &View<PromptEditor<BufferCodegen>>,
prompt_block_id: CustomBlockId,
end_block_id: CustomBlockId,
range: Range<Anchor>,
- codegen: Model<Codegen>,
+ codegen: Model<BufferCodegen>,
workspace: WeakView<Workspace>,
cx: &mut WindowContext,
) -> Self {
@@ -2273,1060 +1561,55 @@ impl InlineAssist {
struct InlineAssistDecorations {
prompt_block_id: CustomBlockId,
- prompt_editor: View<PromptEditor>,
+ prompt_editor: View<PromptEditor<BufferCodegen>>,
removed_line_block_ids: HashSet<CustomBlockId>,
end_block_id: CustomBlockId,
}
-#[derive(Copy, Clone, Debug)]
-pub enum CodegenEvent {
- Finished,
- Undone,
-}
-
-pub struct Codegen {
- alternatives: Vec<Model<CodegenAlternative>>,
- active_alternative: usize,
- seen_alternatives: HashSet<usize>,
- subscriptions: Vec<Subscription>,
- buffer: Model<MultiBuffer>,
- range: Range<Anchor>,
- initial_transaction_id: Option<TransactionId>,
- context_store: Model<ContextStore>,
- telemetry: Arc<Telemetry>,
- builder: Arc<PromptBuilder>,
- is_insertion: bool,
+struct AssistantCodeActionProvider {
+ editor: WeakView<Editor>,
+ workspace: WeakView<Workspace>,
+ thread_store: Option<WeakModel<ThreadStore>>,
}
-impl Codegen {
- pub fn new(
- buffer: Model<MultiBuffer>,
- range: Range<Anchor>,
- initial_transaction_id: Option<TransactionId>,
- context_store: Model<ContextStore>,
- telemetry: Arc<Telemetry>,
- builder: Arc<PromptBuilder>,
- cx: &mut ModelContext<Self>,
- ) -> Self {
- let codegen = cx.new_model(|cx| {
- CodegenAlternative::new(
- buffer.clone(),
- range.clone(),
- false,
- Some(context_store.clone()),
- Some(telemetry.clone()),
- builder.clone(),
- cx,
- )
- });
- let mut this = Self {
- is_insertion: range.to_offset(&buffer.read(cx).snapshot(cx)).is_empty(),
- alternatives: vec![codegen],
- active_alternative: 0,
- seen_alternatives: HashSet::default(),
- subscriptions: Vec::new(),
- buffer,
- range,
- initial_transaction_id,
- context_store,
- telemetry,
- builder,
- };
- this.activate(0, cx);
- this
- }
+impl CodeActionProvider for AssistantCodeActionProvider {
+ fn code_actions(
+ &self,
+ buffer: &Model<Buffer>,
+ range: Range<text::Anchor>,
+ cx: &mut WindowContext,
+ ) -> Task<Result<Vec<CodeAction>>> {
+ if !AssistantSettings::get_global(cx).enabled {
+ return Task::ready(Ok(Vec::new()));
+ }
- fn subscribe_to_alternative(&mut self, cx: &mut ModelContext<Self>) {
- let codegen = self.active_alternative().clone();
- self.subscriptions.clear();
- self.subscriptions
- .push(cx.observe(&codegen, |_, _, cx| cx.notify()));
- self.subscriptions
- .push(cx.subscribe(&codegen, |_, _, event, cx| cx.emit(*event)));
- }
+ let snapshot = buffer.read(cx).snapshot();
+ let mut range = range.to_point(&snapshot);
- fn active_alternative(&self) -> &Model<CodegenAlternative> {
- &self.alternatives[self.active_alternative]
- }
+ // Expand the range to line boundaries.
+ range.start.column = 0;
+ range.end.column = snapshot.line_len(range.end.row);
- fn status<'a>(&self, cx: &'a AppContext) -> &'a CodegenStatus {
- &self.active_alternative().read(cx).status
- }
+ let mut has_diagnostics = false;
+ for diagnostic in snapshot.diagnostics_in_range::<_, Point>(range.clone(), false) {
+ range.start = cmp::min(range.start, diagnostic.range.start);
+ range.end = cmp::max(range.end, diagnostic.range.end);
+ has_diagnostics = true;
+ }
+ if has_diagnostics {
+ if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) {
+ if let Some(symbol) = symbols_containing_start.last() {
+ range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
+ range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
+ }
+ }
- fn alternative_count(&self, cx: &AppContext) -> usize {
- LanguageModelRegistry::read_global(cx)
- .inline_alternative_models()
- .len()
- + 1
- }
-
- pub fn cycle_prev(&mut self, cx: &mut ModelContext<Self>) {
- let next_active_ix = if self.active_alternative == 0 {
- self.alternatives.len() - 1
- } else {
- self.active_alternative - 1
- };
- self.activate(next_active_ix, cx);
- }
-
- pub fn cycle_next(&mut self, cx: &mut ModelContext<Self>) {
- let next_active_ix = (self.active_alternative + 1) % self.alternatives.len();
- self.activate(next_active_ix, cx);
- }
-
- fn activate(&mut self, index: usize, cx: &mut ModelContext<Self>) {
- self.active_alternative()
- .update(cx, |codegen, cx| codegen.set_active(false, cx));
- self.seen_alternatives.insert(index);
- self.active_alternative = index;
- self.active_alternative()
- .update(cx, |codegen, cx| codegen.set_active(true, cx));
- self.subscribe_to_alternative(cx);
- cx.notify();
- }
-
- pub fn start(&mut self, user_prompt: String, cx: &mut ModelContext<Self>) -> Result<()> {
- let alternative_models = LanguageModelRegistry::read_global(cx)
- .inline_alternative_models()
- .to_vec();
-
- self.active_alternative()
- .update(cx, |alternative, cx| alternative.undo(cx));
- self.activate(0, cx);
- self.alternatives.truncate(1);
-
- for _ in 0..alternative_models.len() {
- self.alternatives.push(cx.new_model(|cx| {
- CodegenAlternative::new(
- self.buffer.clone(),
- self.range.clone(),
- false,
- Some(self.context_store.clone()),
- Some(self.telemetry.clone()),
- self.builder.clone(),
- cx,
- )
- }));
- }
-
- let primary_model = LanguageModelRegistry::read_global(cx)
- .active_model()
- .context("no active model")?;
-
- for (model, alternative) in iter::once(primary_model)
- .chain(alternative_models)
- .zip(&self.alternatives)
- {
- alternative.update(cx, |alternative, cx| {
- alternative.start(user_prompt.clone(), model.clone(), cx)
- })?;
- }
-
- Ok(())
- }
-
- pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
- for codegen in &self.alternatives {
- codegen.update(cx, |codegen, cx| codegen.stop(cx));
- }
- }
-
- pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
- self.active_alternative()
- .update(cx, |codegen, cx| codegen.undo(cx));
-
- self.buffer.update(cx, |buffer, cx| {
- if let Some(transaction_id) = self.initial_transaction_id.take() {
- buffer.undo_transaction(transaction_id, cx);
- buffer.refresh_preview(cx);
- }
- });
- }
-
- pub fn buffer(&self, cx: &AppContext) -> Model<MultiBuffer> {
- self.active_alternative().read(cx).buffer.clone()
- }
-
- pub fn old_buffer(&self, cx: &AppContext) -> Model<Buffer> {
- self.active_alternative().read(cx).old_buffer.clone()
- }
-
- pub fn snapshot(&self, cx: &AppContext) -> MultiBufferSnapshot {
- self.active_alternative().read(cx).snapshot.clone()
- }
-
- pub fn edit_position(&self, cx: &AppContext) -> Option<Anchor> {
- self.active_alternative().read(cx).edit_position
- }
-
- fn diff<'a>(&self, cx: &'a AppContext) -> &'a Diff {
- &self.active_alternative().read(cx).diff
- }
-
- pub fn last_equal_ranges<'a>(&self, cx: &'a AppContext) -> &'a [Range<Anchor>] {
- self.active_alternative().read(cx).last_equal_ranges()
- }
-}
-
-impl EventEmitter<CodegenEvent> for Codegen {}
-
-pub struct CodegenAlternative {
- buffer: Model<MultiBuffer>,
- old_buffer: Model<Buffer>,
- snapshot: MultiBufferSnapshot,
- edit_position: Option<Anchor>,
- range: Range<Anchor>,
- last_equal_ranges: Vec<Range<Anchor>>,
- transformation_transaction_id: Option<TransactionId>,
- status: CodegenStatus,
- generation: Task<()>,
- diff: Diff,
- context_store: Option<Model<ContextStore>>,
- telemetry: Option<Arc<Telemetry>>,
- _subscription: gpui::Subscription,
- builder: Arc<PromptBuilder>,
- active: bool,
- edits: Vec<(Range<Anchor>, String)>,
- line_operations: Vec<LineOperation>,
- request: Option<LanguageModelRequest>,
- elapsed_time: Option<f64>,
- completion: Option<String>,
- message_id: Option<String>,
-}
-
-#[derive(Default)]
-struct Diff {
- deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)>,
- inserted_row_ranges: Vec<Range<Anchor>>,
-}
-
-impl Diff {
- fn is_empty(&self) -> bool {
- self.deleted_row_ranges.is_empty() && self.inserted_row_ranges.is_empty()
- }
-}
-
-impl EventEmitter<CodegenEvent> for CodegenAlternative {}
-
-impl CodegenAlternative {
- pub fn new(
- buffer: Model<MultiBuffer>,
- range: Range<Anchor>,
- active: bool,
- context_store: Option<Model<ContextStore>>,
- telemetry: Option<Arc<Telemetry>>,
- builder: Arc<PromptBuilder>,
- cx: &mut ModelContext<Self>,
- ) -> Self {
- let snapshot = buffer.read(cx).snapshot(cx);
-
- let (old_buffer, _, _) = buffer
- .read(cx)
- .range_to_buffer_ranges(range.clone(), cx)
- .pop()
- .unwrap();
- let old_buffer = cx.new_model(|cx| {
- let old_buffer = old_buffer.read(cx);
- let text = old_buffer.as_rope().clone();
- let line_ending = old_buffer.line_ending();
- let language = old_buffer.language().cloned();
- let language_registry = old_buffer.language_registry();
-
- let mut buffer = Buffer::local_normalized(text, line_ending, cx);
- buffer.set_language(language, cx);
- if let Some(language_registry) = language_registry {
- buffer.set_language_registry(language_registry)
- }
- buffer
- });
-
- Self {
- buffer: buffer.clone(),
- old_buffer,
- edit_position: None,
- message_id: None,
- snapshot,
- last_equal_ranges: Default::default(),
- transformation_transaction_id: None,
- status: CodegenStatus::Idle,
- generation: Task::ready(()),
- diff: Diff::default(),
- context_store,
- telemetry,
- _subscription: cx.subscribe(&buffer, Self::handle_buffer_event),
- builder,
- active,
- edits: Vec::new(),
- line_operations: Vec::new(),
- range,
- request: None,
- elapsed_time: None,
- completion: None,
- }
- }
-
- fn set_active(&mut self, active: bool, cx: &mut ModelContext<Self>) {
- if active != self.active {
- self.active = active;
-
- if self.active {
- let edits = self.edits.clone();
- self.apply_edits(edits, cx);
- if matches!(self.status, CodegenStatus::Pending) {
- let line_operations = self.line_operations.clone();
- self.reapply_line_based_diff(line_operations, cx);
- } else {
- self.reapply_batch_diff(cx).detach();
- }
- } else if let Some(transaction_id) = self.transformation_transaction_id.take() {
- self.buffer.update(cx, |buffer, cx| {
- buffer.undo_transaction(transaction_id, cx);
- buffer.forget_transaction(transaction_id, cx);
- });
- }
- }
- }
-
- fn handle_buffer_event(
- &mut self,
- _buffer: Model<MultiBuffer>,
- event: &multi_buffer::Event,
- cx: &mut ModelContext<Self>,
- ) {
- if let multi_buffer::Event::TransactionUndone { transaction_id } = event {
- if self.transformation_transaction_id == Some(*transaction_id) {
- self.transformation_transaction_id = None;
- self.generation = Task::ready(());
- cx.emit(CodegenEvent::Undone);
- }
- }
- }
-
- pub fn last_equal_ranges(&self) -> &[Range<Anchor>] {
- &self.last_equal_ranges
- }
-
- pub fn start(
- &mut self,
- user_prompt: String,
- model: Arc<dyn LanguageModel>,
- cx: &mut ModelContext<Self>,
- ) -> Result<()> {
- if let Some(transformation_transaction_id) = self.transformation_transaction_id.take() {
- self.buffer.update(cx, |buffer, cx| {
- buffer.undo_transaction(transformation_transaction_id, cx);
- });
- }
-
- self.edit_position = Some(self.range.start.bias_right(&self.snapshot));
-
- let api_key = model.api_key(cx);
- let telemetry_id = model.telemetry_id();
- let provider_id = model.provider_id();
- let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
- if user_prompt.trim().to_lowercase() == "delete" {
- async { Ok(LanguageModelTextStream::default()) }.boxed_local()
- } else {
- let request = self.build_request(user_prompt, cx)?;
- self.request = Some(request.clone());
-
- cx.spawn(|_, cx| async move { model.stream_completion_text(request, &cx).await })
- .boxed_local()
- };
- self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
- Ok(())
- }
-
- fn build_request(
- &self,
- user_prompt: String,
- cx: &mut AppContext,
- ) -> Result<LanguageModelRequest> {
- let buffer = self.buffer.read(cx).snapshot(cx);
- let language = buffer.language_at(self.range.start);
- let language_name = if let Some(language) = language.as_ref() {
- if Arc::ptr_eq(language, &language::PLAIN_TEXT) {
- None
- } else {
- Some(language.name())
- }
- } else {
- None
- };
-
- let language_name = language_name.as_ref();
- let start = buffer.point_to_buffer_offset(self.range.start);
- let end = buffer.point_to_buffer_offset(self.range.end);
- let (buffer, range) = if let Some((start, end)) = start.zip(end) {
- let (start_buffer, start_buffer_offset) = start;
- let (end_buffer, end_buffer_offset) = end;
- if start_buffer.remote_id() == end_buffer.remote_id() {
- (start_buffer.clone(), start_buffer_offset..end_buffer_offset)
- } else {
- return Err(anyhow::anyhow!("invalid transformation range"));
- }
- } else {
- return Err(anyhow::anyhow!("invalid transformation range"));
- };
-
- let prompt = self
- .builder
- .generate_inline_transformation_prompt(user_prompt, language_name, buffer, range)
- .map_err(|e| anyhow::anyhow!("Failed to generate content prompt: {}", e))?;
-
- let mut request_message = LanguageModelRequestMessage {
- role: Role::User,
- content: Vec::new(),
- cache: false,
- };
-
- if let Some(context_store) = &self.context_store {
- let context = context_store.update(cx, |this, _cx| this.context().clone());
- attach_context_to_message(&mut request_message, context);
- }
-
- request_message.content.push(prompt.into());
-
- Ok(LanguageModelRequest {
- tools: Vec::new(),
- stop: Vec::new(),
- temperature: None,
- messages: vec![request_message],
- })
- }
-
- pub fn handle_stream(
- &mut self,
- model_telemetry_id: String,
- model_provider_id: String,
- model_api_key: Option<String>,
- stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
- cx: &mut ModelContext<Self>,
- ) {
- let start_time = Instant::now();
- let snapshot = self.snapshot.clone();
- let selected_text = snapshot
- .text_for_range(self.range.start..self.range.end)
- .collect::<Rope>();
-
- let selection_start = self.range.start.to_point(&snapshot);
-
- // Start with the indentation of the first line in the selection
- let mut suggested_line_indent = snapshot
- .suggested_indents(selection_start.row..=selection_start.row, cx)
- .into_values()
- .next()
- .unwrap_or_else(|| snapshot.indent_size_for_line(MultiBufferRow(selection_start.row)));
-
- // If the first line in the selection does not have indentation, check the following lines
- if suggested_line_indent.len == 0 && suggested_line_indent.kind == IndentKind::Space {
- for row in selection_start.row..=self.range.end.to_point(&snapshot).row {
- let line_indent = snapshot.indent_size_for_line(MultiBufferRow(row));
- // Prefer tabs if a line in the selection uses tabs as indentation
- if line_indent.kind == IndentKind::Tab {
- suggested_line_indent.kind = IndentKind::Tab;
- break;
- }
- }
- }
-
- let http_client = cx.http_client().clone();
- let telemetry = self.telemetry.clone();
- let language_name = {
- let multibuffer = self.buffer.read(cx);
- let ranges = multibuffer.range_to_buffer_ranges(self.range.clone(), cx);
- ranges
- .first()
- .and_then(|(buffer, _, _)| buffer.read(cx).language())
- .map(|language| language.name())
- };
-
- self.diff = Diff::default();
- self.status = CodegenStatus::Pending;
- let mut edit_start = self.range.start.to_offset(&snapshot);
- let completion = Arc::new(Mutex::new(String::new()));
- let completion_clone = completion.clone();
-
- self.generation = cx.spawn(|codegen, mut cx| {
- async move {
- let stream = stream.await;
- let message_id = stream
- .as_ref()
- .ok()
- .and_then(|stream| stream.message_id.clone());
- let generate = async {
- let (mut diff_tx, mut diff_rx) = mpsc::channel(1);
- let executor = cx.background_executor().clone();
- let message_id = message_id.clone();
- let line_based_stream_diff: Task<anyhow::Result<()>> =
- cx.background_executor().spawn(async move {
- let mut response_latency = None;
- let request_start = Instant::now();
- let diff = async {
- let chunks = StripInvalidSpans::new(stream?.stream);
- futures::pin_mut!(chunks);
- let mut diff = StreamingDiff::new(selected_text.to_string());
- let mut line_diff = LineDiff::default();
-
- let mut new_text = String::new();
- let mut base_indent = None;
- let mut line_indent = None;
- let mut first_line = true;
-
- while let Some(chunk) = chunks.next().await {
- if response_latency.is_none() {
- response_latency = Some(request_start.elapsed());
- }
- let chunk = chunk?;
- completion_clone.lock().push_str(&chunk);
-
- let mut lines = chunk.split('\n').peekable();
- while let Some(line) = lines.next() {
- new_text.push_str(line);
- if line_indent.is_none() {
- if let Some(non_whitespace_ch_ix) =
- new_text.find(|ch: char| !ch.is_whitespace())
- {
- line_indent = Some(non_whitespace_ch_ix);
- base_indent = base_indent.or(line_indent);
-
- let line_indent = line_indent.unwrap();
- let base_indent = base_indent.unwrap();
- let indent_delta =
- line_indent as i32 - base_indent as i32;
- let mut corrected_indent_len = cmp::max(
- 0,
- suggested_line_indent.len as i32 + indent_delta,
- )
- as usize;
- if first_line {
- corrected_indent_len = corrected_indent_len
- .saturating_sub(
- selection_start.column as usize,
- );
- }
-
- let indent_char = suggested_line_indent.char();
- let mut indent_buffer = [0; 4];
- let indent_str =
- indent_char.encode_utf8(&mut indent_buffer);
- new_text.replace_range(
- ..line_indent,
- &indent_str.repeat(corrected_indent_len),
- );
- }
- }
-
- if line_indent.is_some() {
- let char_ops = diff.push_new(&new_text);
- line_diff
- .push_char_operations(&char_ops, &selected_text);
- diff_tx
- .send((char_ops, line_diff.line_operations()))
- .await?;
- new_text.clear();
- }
-
- if lines.peek().is_some() {
- let char_ops = diff.push_new("\n");
- line_diff
- .push_char_operations(&char_ops, &selected_text);
- diff_tx
- .send((char_ops, line_diff.line_operations()))
- .await?;
- if line_indent.is_none() {
- // Don't write out the leading indentation in empty lines on the next line
- // This is the case where the above if statement didn't clear the buffer
- new_text.clear();
- }
- line_indent = None;
- first_line = false;
- }
- }
- }
-
- let mut char_ops = diff.push_new(&new_text);
- char_ops.extend(diff.finish());
- line_diff.push_char_operations(&char_ops, &selected_text);
- line_diff.finish(&selected_text);
- diff_tx
- .send((char_ops, line_diff.line_operations()))
- .await?;
-
- anyhow::Ok(())
- };
-
- let result = diff.await;
-
- let error_message =
- result.as_ref().err().map(|error| error.to_string());
- report_assistant_event(
- AssistantEvent {
- conversation_id: None,
- message_id,
- kind: AssistantKind::Inline,
- phase: AssistantPhase::Response,
- model: model_telemetry_id,
- model_provider: model_provider_id.to_string(),
- response_latency,
- error_message,
- language_name: language_name.map(|name| name.to_proto()),
- },
- telemetry,
- http_client,
- model_api_key,
- &executor,
- );
-
- result?;
- Ok(())
- });
-
- while let Some((char_ops, line_ops)) = diff_rx.next().await {
- codegen.update(&mut cx, |codegen, cx| {
- codegen.last_equal_ranges.clear();
-
- let edits = char_ops
- .into_iter()
- .filter_map(|operation| match operation {
- CharOperation::Insert { text } => {
- let edit_start = snapshot.anchor_after(edit_start);
- Some((edit_start..edit_start, text))
- }
- CharOperation::Delete { bytes } => {
- let edit_end = edit_start + bytes;
- let edit_range = snapshot.anchor_after(edit_start)
- ..snapshot.anchor_before(edit_end);
- edit_start = edit_end;
- Some((edit_range, String::new()))
- }
- CharOperation::Keep { bytes } => {
- let edit_end = edit_start + bytes;
- let edit_range = snapshot.anchor_after(edit_start)
- ..snapshot.anchor_before(edit_end);
- edit_start = edit_end;
- codegen.last_equal_ranges.push(edit_range);
- None
- }
- })
- .collect::<Vec<_>>();
-
- if codegen.active {
- codegen.apply_edits(edits.iter().cloned(), cx);
- codegen.reapply_line_based_diff(line_ops.iter().cloned(), cx);
- }
- codegen.edits.extend(edits);
- codegen.line_operations = line_ops;
- codegen.edit_position = Some(snapshot.anchor_after(edit_start));
-
- cx.notify();
- })?;
- }
-
- // Streaming stopped and we have the new text in the buffer, and a line-based diff applied for the whole new buffer.
- // That diff is not what a regular diff is and might look unexpected, ergo apply a regular diff.
- // It's fine to apply even if the rest of the line diffing fails, as no more hunks are coming through `diff_rx`.
- let batch_diff_task =
- codegen.update(&mut cx, |codegen, cx| codegen.reapply_batch_diff(cx))?;
- let (line_based_stream_diff, ()) =
- join!(line_based_stream_diff, batch_diff_task);
- line_based_stream_diff?;
-
- anyhow::Ok(())
- };
-
- let result = generate.await;
- let elapsed_time = start_time.elapsed().as_secs_f64();
-
- codegen
- .update(&mut cx, |this, cx| {
- this.message_id = message_id;
- this.last_equal_ranges.clear();
- if let Err(error) = result {
- this.status = CodegenStatus::Error(error);
- } else {
- this.status = CodegenStatus::Done;
- }
- this.elapsed_time = Some(elapsed_time);
- this.completion = Some(completion.lock().clone());
- cx.emit(CodegenEvent::Finished);
- cx.notify();
- })
- .ok();
- }
- });
- cx.notify();
- }
-
- pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
- self.last_equal_ranges.clear();
- if self.diff.is_empty() {
- self.status = CodegenStatus::Idle;
- } else {
- self.status = CodegenStatus::Done;
- }
- self.generation = Task::ready(());
- cx.emit(CodegenEvent::Finished);
- cx.notify();
- }
-
- pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
- self.buffer.update(cx, |buffer, cx| {
- if let Some(transaction_id) = self.transformation_transaction_id.take() {
- buffer.undo_transaction(transaction_id, cx);
- buffer.refresh_preview(cx);
- }
- });
- }
-
- fn apply_edits(
- &mut self,
- edits: impl IntoIterator<Item = (Range<Anchor>, String)>,
- cx: &mut ModelContext<CodegenAlternative>,
- ) {
- let transaction = self.buffer.update(cx, |buffer, cx| {
- // Avoid grouping assistant edits with user edits.
- buffer.finalize_last_transaction(cx);
- buffer.start_transaction(cx);
- buffer.edit(edits, None, cx);
- buffer.end_transaction(cx)
- });
-
- if let Some(transaction) = transaction {
- if let Some(first_transaction) = self.transformation_transaction_id {
- // Group all assistant edits into the first transaction.
- self.buffer.update(cx, |buffer, cx| {
- buffer.merge_transactions(transaction, first_transaction, cx)
- });
- } else {
- self.transformation_transaction_id = Some(transaction);
- self.buffer
- .update(cx, |buffer, cx| buffer.finalize_last_transaction(cx));
- }
- }
- }
-
- fn reapply_line_based_diff(
- &mut self,
- line_operations: impl IntoIterator<Item = LineOperation>,
- cx: &mut ModelContext<Self>,
- ) {
- let old_snapshot = self.snapshot.clone();
- let old_range = self.range.to_point(&old_snapshot);
- let new_snapshot = self.buffer.read(cx).snapshot(cx);
- let new_range = self.range.to_point(&new_snapshot);
-
- let mut old_row = old_range.start.row;
- let mut new_row = new_range.start.row;
-
- self.diff.deleted_row_ranges.clear();
- self.diff.inserted_row_ranges.clear();
- for operation in line_operations {
- match operation {
- LineOperation::Keep { lines } => {
- old_row += lines;
- new_row += lines;
- }
- LineOperation::Delete { lines } => {
- let old_end_row = old_row + lines - 1;
- let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
-
- if let Some((_, last_deleted_row_range)) =
- self.diff.deleted_row_ranges.last_mut()
- {
- if *last_deleted_row_range.end() + 1 == old_row {
- *last_deleted_row_range = *last_deleted_row_range.start()..=old_end_row;
- } else {
- self.diff
- .deleted_row_ranges
- .push((new_row, old_row..=old_end_row));
- }
- } else {
- self.diff
- .deleted_row_ranges
- .push((new_row, old_row..=old_end_row));
- }
-
- old_row += lines;
- }
- LineOperation::Insert { lines } => {
- let new_end_row = new_row + lines - 1;
- let start = new_snapshot.anchor_before(Point::new(new_row, 0));
- let end = new_snapshot.anchor_before(Point::new(
- new_end_row,
- new_snapshot.line_len(MultiBufferRow(new_end_row)),
- ));
- self.diff.inserted_row_ranges.push(start..end);
- new_row += lines;
- }
- }
-
- cx.notify();
- }
- }
-
- fn reapply_batch_diff(&mut self, cx: &mut ModelContext<Self>) -> Task<()> {
- let old_snapshot = self.snapshot.clone();
- let old_range = self.range.to_point(&old_snapshot);
- let new_snapshot = self.buffer.read(cx).snapshot(cx);
- let new_range = self.range.to_point(&new_snapshot);
-
- cx.spawn(|codegen, mut cx| async move {
- let (deleted_row_ranges, inserted_row_ranges) = cx
- .background_executor()
- .spawn(async move {
- let old_text = old_snapshot
- .text_for_range(
- Point::new(old_range.start.row, 0)
- ..Point::new(
- old_range.end.row,
- old_snapshot.line_len(MultiBufferRow(old_range.end.row)),
- ),
- )
- .collect::<String>();
- let new_text = new_snapshot
- .text_for_range(
- Point::new(new_range.start.row, 0)
- ..Point::new(
- new_range.end.row,
- new_snapshot.line_len(MultiBufferRow(new_range.end.row)),
- ),
- )
- .collect::<String>();
-
- let mut old_row = old_range.start.row;
- let mut new_row = new_range.start.row;
- let batch_diff =
- similar::TextDiff::from_lines(old_text.as_str(), new_text.as_str());
-
- let mut deleted_row_ranges: Vec<(Anchor, RangeInclusive<u32>)> = Vec::new();
- let mut inserted_row_ranges = Vec::new();
- for change in batch_diff.iter_all_changes() {
- let line_count = change.value().lines().count() as u32;
- match change.tag() {
- similar::ChangeTag::Equal => {
- old_row += line_count;
- new_row += line_count;
- }
- similar::ChangeTag::Delete => {
- let old_end_row = old_row + line_count - 1;
- let new_row = new_snapshot.anchor_before(Point::new(new_row, 0));
-
- if let Some((_, last_deleted_row_range)) =
- deleted_row_ranges.last_mut()
- {
- if *last_deleted_row_range.end() + 1 == old_row {
- *last_deleted_row_range =
- *last_deleted_row_range.start()..=old_end_row;
- } else {
- deleted_row_ranges.push((new_row, old_row..=old_end_row));
- }
- } else {
- deleted_row_ranges.push((new_row, old_row..=old_end_row));
- }
-
- old_row += line_count;
- }
- similar::ChangeTag::Insert => {
- let new_end_row = new_row + line_count - 1;
- let start = new_snapshot.anchor_before(Point::new(new_row, 0));
- let end = new_snapshot.anchor_before(Point::new(
- new_end_row,
- new_snapshot.line_len(MultiBufferRow(new_end_row)),
- ));
- inserted_row_ranges.push(start..end);
- new_row += line_count;
- }
- }
- }
-
- (deleted_row_ranges, inserted_row_ranges)
- })
- .await;
-
- codegen
- .update(&mut cx, |codegen, cx| {
- codegen.diff.deleted_row_ranges = deleted_row_ranges;
- codegen.diff.inserted_row_ranges = inserted_row_ranges;
- cx.notify();
- })
- .ok();
- })
- }
-}
-
-struct StripInvalidSpans<T> {
- stream: T,
- stream_done: bool,
- buffer: String,
- first_line: bool,
- line_end: bool,
- starts_with_code_block: bool,
-}
-
-impl<T> StripInvalidSpans<T>
-where
- T: Stream<Item = Result<String>>,
-{
- fn new(stream: T) -> Self {
- Self {
- stream,
- stream_done: false,
- buffer: String::new(),
- first_line: true,
- line_end: false,
- starts_with_code_block: false,
- }
- }
-}
-
-impl<T> Stream for StripInvalidSpans<T>
-where
- T: Stream<Item = Result<String>>,
-{
- type Item = Result<String>;
-
- fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Option<Self::Item>> {
- const CODE_BLOCK_DELIMITER: &str = "```";
- const CURSOR_SPAN: &str = "<|CURSOR|>";
-
- let this = unsafe { self.get_unchecked_mut() };
- loop {
- if !this.stream_done {
- let mut stream = unsafe { Pin::new_unchecked(&mut this.stream) };
- match stream.as_mut().poll_next(cx) {
- Poll::Ready(Some(Ok(chunk))) => {
- this.buffer.push_str(&chunk);
- }
- Poll::Ready(Some(Err(error))) => return Poll::Ready(Some(Err(error))),
- Poll::Ready(None) => {
- this.stream_done = true;
- }
- Poll::Pending => return Poll::Pending,
- }
- }
-
- let mut chunk = String::new();
- let mut consumed = 0;
- if !this.buffer.is_empty() {
- let mut lines = this.buffer.split('\n').enumerate().peekable();
- while let Some((line_ix, line)) = lines.next() {
- if line_ix > 0 {
- this.first_line = false;
- }
-
- if this.first_line {
- let trimmed_line = line.trim();
- if lines.peek().is_some() {
- if trimmed_line.starts_with(CODE_BLOCK_DELIMITER) {
- consumed += line.len() + 1;
- this.starts_with_code_block = true;
- continue;
- }
- } else if trimmed_line.is_empty()
- || prefixes(CODE_BLOCK_DELIMITER)
- .any(|prefix| trimmed_line.starts_with(prefix))
- {
- break;
- }
- }
-
- let line_without_cursor = line.replace(CURSOR_SPAN, "");
- if lines.peek().is_some() {
- if this.line_end {
- chunk.push('\n');
- }
-
- chunk.push_str(&line_without_cursor);
- this.line_end = true;
- consumed += line.len() + 1;
- } else if this.stream_done {
- if !this.starts_with_code_block
- || !line_without_cursor.trim().ends_with(CODE_BLOCK_DELIMITER)
- {
- if this.line_end {
- chunk.push('\n');
- }
-
- chunk.push_str(&line);
- }
-
- consumed += line.len();
- } else {
- let trimmed_line = line.trim();
- if trimmed_line.is_empty()
- || prefixes(CURSOR_SPAN).any(|prefix| trimmed_line.ends_with(prefix))
- || prefixes(CODE_BLOCK_DELIMITER)
- .any(|prefix| trimmed_line.ends_with(prefix))
- {
- break;
- } else {
- if this.line_end {
- chunk.push('\n');
- this.line_end = false;
- }
-
- chunk.push_str(&line_without_cursor);
- consumed += line.len();
- }
- }
- }
- }
-
- this.buffer = this.buffer.split_off(consumed);
- if !chunk.is_empty() {
- return Poll::Ready(Some(Ok(chunk)));
- } else if this.stream_done {
- return Poll::Ready(None);
- }
- }
- }
-}
-
-struct AssistantCodeActionProvider {
- editor: WeakView<Editor>,
- workspace: WeakView<Workspace>,
- thread_store: Option<WeakModel<ThreadStore>>,
-}
-
-impl CodeActionProvider for AssistantCodeActionProvider {
- fn code_actions(
- &self,
- buffer: &Model<Buffer>,
- range: Range<text::Anchor>,
- cx: &mut WindowContext,
- ) -> Task<Result<Vec<CodeAction>>> {
- if !AssistantSettings::get_global(cx).enabled {
- return Task::ready(Ok(Vec::new()));
- }
-
- let snapshot = buffer.read(cx).snapshot();
- let mut range = range.to_point(&snapshot);
-
- // Expand the range to line boundaries.
- range.start.column = 0;
- range.end.column = snapshot.line_len(range.end.row);
-
- let mut has_diagnostics = false;
- for diagnostic in snapshot.diagnostics_in_range::<_, Point>(range.clone(), false) {
- range.start = cmp::min(range.start, diagnostic.range.start);
- range.end = cmp::max(range.end, diagnostic.range.end);
- has_diagnostics = true;
- }
- if has_diagnostics {
- if let Some(symbols_containing_start) = snapshot.symbols_containing(range.start, None) {
- if let Some(symbol) = symbols_containing_start.last() {
- range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
- range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
- }
- }
-
- if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) {
- if let Some(symbol) = symbols_containing_end.last() {
- range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
- range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
- }
- }
+ if let Some(symbols_containing_end) = snapshot.symbols_containing(range.end, None) {
+ if let Some(symbol) = symbols_containing_end.last() {
+ range.start = cmp::min(range.start, symbol.range.start.to_point(&snapshot));
+ range.end = cmp::max(range.end, symbol.range.end.to_point(&snapshot));
+ }
+ }
Task::ready(Ok(vec![CodeAction {
server_id: language::LanguageServerId(0),
@@ -1,5 +1,1068 @@
-use gpui::{AnyElement, EventEmitter};
-use ui::{prelude::*, IconButtonShape, Tooltip};
+use crate::buffer_codegen::BufferCodegen;
+use crate::context_picker::ContextPicker;
+use crate::context_store::ContextStore;
+use crate::context_strip::ContextStrip;
+use crate::terminal_codegen::TerminalCodegen;
+use crate::thread_store::ThreadStore;
+use crate::ToggleContextPicker;
+use crate::{
+ assistant_settings::AssistantSettings, CycleNextInlineAssist, CyclePreviousInlineAssist,
+};
+use client::ErrorExt;
+use collections::VecDeque;
+use editor::{
+ actions::{MoveDown, MoveUp},
+ Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, GutterDimensions, MultiBuffer,
+};
+use feature_flags::{FeatureFlagAppExt as _, ZedPro};
+use fs::Fs;
+use gpui::{
+ anchored, deferred, point, AnyElement, AppContext, ClickEvent, CursorStyle, EventEmitter,
+ FocusHandle, FocusableView, FontWeight, Model, Subscription, TextStyle, View, ViewContext,
+ WeakModel, WeakView, WindowContext,
+};
+use language_model::{LanguageModel, LanguageModelRegistry};
+use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
+use parking_lot::Mutex;
+use settings::{update_settings_file, Settings};
+use std::cmp;
+use std::sync::Arc;
+use theme::ThemeSettings;
+use ui::{
+ prelude::*, CheckboxWithLabel, IconButtonShape, KeyBinding, Popover, PopoverMenuHandle, Tooltip,
+};
+use util::ResultExt;
+use workspace::Workspace;
+
+pub struct PromptEditor<T> {
+ pub editor: View<Editor>,
+ mode: PromptEditorMode,
+ context_strip: View<ContextStrip>,
+ context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
+ language_model_selector: View<LanguageModelSelector>,
+ edited_since_done: bool,
+ prompt_history: VecDeque<String>,
+ prompt_history_ix: Option<usize>,
+ pending_prompt: String,
+ _codegen_subscription: Subscription,
+ editor_subscriptions: Vec<Subscription>,
+ show_rate_limit_notice: bool,
+ _phantom: std::marker::PhantomData<T>,
+}
+
+impl<T: 'static> EventEmitter<PromptEditorEvent> for PromptEditor<T> {}
+
+impl<T: 'static> Render for PromptEditor<T> {
+ fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ let mut buttons = Vec::new();
+
+ let spacing = match &self.mode {
+ PromptEditorMode::Buffer {
+ id: _,
+ codegen,
+ gutter_dimensions,
+ } => {
+ let codegen = codegen.read(cx);
+
+ if codegen.alternative_count(cx) > 1 {
+ buttons.push(self.render_cycle_controls(&codegen, cx));
+ }
+
+ let gutter_dimensions = gutter_dimensions.lock();
+
+ gutter_dimensions.full_width() + (gutter_dimensions.margin / 2.0)
+ }
+ PromptEditorMode::Terminal { .. } => Pixels::ZERO,
+ };
+
+ buttons.extend(self.render_buttons(cx));
+
+ v_flex()
+ .border_y_1()
+ .border_color(cx.theme().status().info_border)
+ .size_full()
+ .py(cx.line_height() / 2.5)
+ .child(
+ h_flex()
+ .key_context("PromptEditor")
+ .bg(cx.theme().colors().editor_background)
+ .block_mouse_down()
+ .cursor(CursorStyle::Arrow)
+ .on_action(cx.listener(Self::toggle_context_picker))
+ .on_action(cx.listener(Self::confirm))
+ .on_action(cx.listener(Self::cancel))
+ .on_action(cx.listener(Self::move_up))
+ .on_action(cx.listener(Self::move_down))
+ .capture_action(cx.listener(Self::cycle_prev))
+ .capture_action(cx.listener(Self::cycle_next))
+ .child(
+ h_flex()
+ .w(spacing)
+ .justify_center()
+ .gap_2()
+ .child(LanguageModelSelectorPopoverMenu::new(
+ self.language_model_selector.clone(),
+ IconButton::new("context", IconName::SettingsAlt)
+ .shape(IconButtonShape::Square)
+ .icon_size(IconSize::Small)
+ .icon_color(Color::Muted)
+ .tooltip(move |cx| {
+ Tooltip::with_meta(
+ format!(
+ "Using {}",
+ LanguageModelRegistry::read_global(cx)
+ .active_model()
+ .map(|model| model.name().0)
+ .unwrap_or_else(|| "No model selected".into()),
+ ),
+ None,
+ "Change Model",
+ cx,
+ )
+ }),
+ ))
+ .map(|el| {
+ let CodegenStatus::Error(error) = self.codegen_status(cx) else {
+ return el;
+ };
+
+ let error_message = SharedString::from(error.to_string());
+ if error.error_code() == proto::ErrorCode::RateLimitExceeded
+ && cx.has_flag::<ZedPro>()
+ {
+ el.child(
+ v_flex()
+ .child(
+ IconButton::new(
+ "rate-limit-error",
+ IconName::XCircle,
+ )
+ .toggle_state(self.show_rate_limit_notice)
+ .shape(IconButtonShape::Square)
+ .icon_size(IconSize::Small)
+ .on_click(
+ cx.listener(Self::toggle_rate_limit_notice),
+ ),
+ )
+ .children(self.show_rate_limit_notice.then(|| {
+ deferred(
+ anchored()
+ .position_mode(
+ gpui::AnchoredPositionMode::Local,
+ )
+ .position(point(px(0.), px(24.)))
+ .anchor(gpui::Corner::TopLeft)
+ .child(self.render_rate_limit_notice(cx)),
+ )
+ })),
+ )
+ } else {
+ el.child(
+ div()
+ .id("error")
+ .tooltip(move |cx| {
+ Tooltip::text(error_message.clone(), cx)
+ })
+ .child(
+ Icon::new(IconName::XCircle)
+ .size(IconSize::Small)
+ .color(Color::Error),
+ ),
+ )
+ }
+ }),
+ )
+ .child(div().flex_1().child(self.render_editor(cx)))
+ .child(h_flex().gap_2().pr_6().children(buttons)),
+ )
+ .child(
+ h_flex()
+ .child(h_flex().w(spacing).justify_center().gap_2())
+ .child(self.context_strip.clone()),
+ )
+ }
+}
+
+impl<T: 'static> FocusableView for PromptEditor<T> {
+ fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
+ self.editor.focus_handle(cx)
+ }
+}
+
+impl<T: 'static> PromptEditor<T> {
+ const MAX_LINES: u8 = 8;
+
+ fn codegen_status<'a>(&'a self, cx: &'a AppContext) -> &'a CodegenStatus {
+ match &self.mode {
+ PromptEditorMode::Buffer { codegen, .. } => codegen.read(cx).status(cx),
+ PromptEditorMode::Terminal { codegen, .. } => &codegen.read(cx).status,
+ }
+ }
+
+ fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
+ self.editor_subscriptions.clear();
+ self.editor_subscriptions
+ .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
+ }
+
+ pub fn set_show_cursor_when_unfocused(
+ &mut self,
+ show_cursor_when_unfocused: bool,
+ cx: &mut ViewContext<Self>,
+ ) {
+ self.editor.update(cx, |editor, cx| {
+ editor.set_show_cursor_when_unfocused(show_cursor_when_unfocused, cx)
+ });
+ }
+
+ pub fn unlink(&mut self, cx: &mut ViewContext<Self>) {
+ let prompt = self.prompt(cx);
+ let focus = self.editor.focus_handle(cx).contains_focused(cx);
+ self.editor = cx.new_view(|cx| {
+ let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
+ editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
+ editor.set_placeholder_text(Self::placeholder_text(&self.mode, cx), cx);
+ editor.set_placeholder_text("Add a promptβ¦", cx);
+ editor.set_text(prompt, cx);
+ if focus {
+ editor.focus(cx);
+ }
+ editor
+ });
+ self.subscribe_to_editor(cx);
+ }
+
+ pub fn placeholder_text(mode: &PromptEditorMode, cx: &WindowContext) -> String {
+ let action = match mode {
+ PromptEditorMode::Buffer { codegen, .. } => {
+ if codegen.read(cx).is_insertion {
+ "Generate"
+ } else {
+ "Transform"
+ }
+ }
+ PromptEditorMode::Terminal { .. } => "Generate",
+ };
+
+ let assistant_panel_keybinding = ui::text_for_action(&crate::ToggleFocus, cx)
+ .map(|keybinding| format!("{keybinding} to chat β "))
+ .unwrap_or_default();
+
+ format!("{action}β¦ ({assistant_panel_keybinding}ββ for history)")
+ }
+
+ pub fn prompt(&self, cx: &AppContext) -> String {
+ self.editor.read(cx).text(cx)
+ }
+
+ fn toggle_rate_limit_notice(&mut self, _: &ClickEvent, cx: &mut ViewContext<Self>) {
+ self.show_rate_limit_notice = !self.show_rate_limit_notice;
+ if self.show_rate_limit_notice {
+ cx.focus_view(&self.editor);
+ }
+ cx.notify();
+ }
+
+ fn handle_prompt_editor_events(
+ &mut self,
+ _: View<Editor>,
+ event: &EditorEvent,
+ cx: &mut ViewContext<Self>,
+ ) {
+ match event {
+ EditorEvent::Edited { .. } => {
+ if let Some(workspace) = cx.window_handle().downcast::<Workspace>() {
+ workspace
+ .update(cx, |workspace, cx| {
+ let is_via_ssh = workspace
+ .project()
+ .update(cx, |project, _| project.is_via_ssh());
+
+ workspace
+ .client()
+ .telemetry()
+ .log_edit_event("inline assist", is_via_ssh);
+ })
+ .log_err();
+ }
+ let prompt = self.editor.read(cx).text(cx);
+ if self
+ .prompt_history_ix
+ .map_or(true, |ix| self.prompt_history[ix] != prompt)
+ {
+ self.prompt_history_ix.take();
+ self.pending_prompt = prompt;
+ }
+
+ self.edited_since_done = true;
+ cx.notify();
+ }
+ EditorEvent::Blurred => {
+ if self.show_rate_limit_notice {
+ self.show_rate_limit_notice = false;
+ cx.notify();
+ }
+ }
+ _ => {}
+ }
+ }
+
+ fn toggle_context_picker(&mut self, _: &ToggleContextPicker, cx: &mut ViewContext<Self>) {
+ self.context_picker_menu_handle.toggle(cx);
+ }
+
+ fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
+ match self.codegen_status(cx) {
+ CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
+ cx.emit(PromptEditorEvent::CancelRequested);
+ }
+ CodegenStatus::Pending => {
+ cx.emit(PromptEditorEvent::StopRequested);
+ }
+ }
+ }
+
+ fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
+ match self.codegen_status(cx) {
+ CodegenStatus::Idle => {
+ cx.emit(PromptEditorEvent::StartRequested);
+ }
+ CodegenStatus::Pending => {
+ cx.emit(PromptEditorEvent::DismissRequested);
+ }
+ CodegenStatus::Done => {
+ if self.edited_since_done {
+ cx.emit(PromptEditorEvent::StartRequested);
+ } else {
+ cx.emit(PromptEditorEvent::ConfirmRequested { execute: false });
+ }
+ }
+ CodegenStatus::Error(_) => {
+ cx.emit(PromptEditorEvent::StartRequested);
+ }
+ }
+ }
+
+ fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
+ if let Some(ix) = self.prompt_history_ix {
+ if ix > 0 {
+ self.prompt_history_ix = Some(ix - 1);
+ let prompt = self.prompt_history[ix - 1].as_str();
+ self.editor.update(cx, |editor, cx| {
+ editor.set_text(prompt, cx);
+ editor.move_to_beginning(&Default::default(), cx);
+ });
+ }
+ } else if !self.prompt_history.is_empty() {
+ self.prompt_history_ix = Some(self.prompt_history.len() - 1);
+ let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
+ self.editor.update(cx, |editor, cx| {
+ editor.set_text(prompt, cx);
+ editor.move_to_beginning(&Default::default(), cx);
+ });
+ }
+ }
+
+ fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
+ if let Some(ix) = self.prompt_history_ix {
+ if ix < self.prompt_history.len() - 1 {
+ self.prompt_history_ix = Some(ix + 1);
+ let prompt = self.prompt_history[ix + 1].as_str();
+ self.editor.update(cx, |editor, cx| {
+ editor.set_text(prompt, cx);
+ editor.move_to_end(&Default::default(), cx)
+ });
+ } else {
+ self.prompt_history_ix = None;
+ let prompt = self.pending_prompt.as_str();
+ self.editor.update(cx, |editor, cx| {
+ editor.set_text(prompt, cx);
+ editor.move_to_end(&Default::default(), cx)
+ });
+ }
+ }
+ }
+
+ fn render_buttons(&self, cx: &mut ViewContext<Self>) -> Vec<AnyElement> {
+ let mode = match &self.mode {
+ PromptEditorMode::Buffer { codegen, .. } => {
+ let codegen = codegen.read(cx);
+ if codegen.is_insertion {
+ GenerationMode::Generate
+ } else {
+ GenerationMode::Transform
+ }
+ }
+ PromptEditorMode::Terminal { .. } => GenerationMode::Generate,
+ };
+
+ let codegen_status = self.codegen_status(cx);
+
+ match codegen_status {
+ CodegenStatus::Idle => {
+ vec![
+ IconButton::new("cancel", IconName::Close)
+ .icon_color(Color::Muted)
+ .shape(IconButtonShape::Square)
+ .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
+ .on_click(
+ cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)),
+ )
+ .into_any_element(),
+ Button::new("start", mode.start_label())
+ .icon(IconName::Return)
+ .icon_color(Color::Muted)
+ .on_click(
+ cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StartRequested)),
+ )
+ .into_any_element(),
+ ]
+ }
+ CodegenStatus::Pending => vec![
+ IconButton::new("cancel", IconName::Close)
+ .icon_color(Color::Muted)
+ .shape(IconButtonShape::Square)
+ .tooltip(|cx| Tooltip::text("Cancel Assist", cx))
+ .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)))
+ .into_any_element(),
+ IconButton::new("stop", IconName::Stop)
+ .icon_color(Color::Error)
+ .shape(IconButtonShape::Square)
+ .tooltip(move |cx| {
+ Tooltip::with_meta(
+ mode.tooltip_interrupt(),
+ Some(&menu::Cancel),
+ "Changes won't be discarded",
+ cx,
+ )
+ })
+ .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::StopRequested)))
+ .into_any_element(),
+ ],
+ CodegenStatus::Done | CodegenStatus::Error(_) => {
+ let cancel = IconButton::new("cancel", IconName::Close)
+ .icon_color(Color::Muted)
+ .shape(IconButtonShape::Square)
+ .tooltip(|cx| Tooltip::for_action("Cancel Assist", &menu::Cancel, cx))
+ .on_click(cx.listener(|_, _, cx| cx.emit(PromptEditorEvent::CancelRequested)))
+ .into_any_element();
+
+ let has_error = matches!(codegen_status, CodegenStatus::Error(_));
+ if has_error || self.edited_since_done {
+ vec![
+ cancel,
+ IconButton::new("restart", IconName::RotateCw)
+ .icon_color(Color::Info)
+ .shape(IconButtonShape::Square)
+ .tooltip(move |cx| {
+ Tooltip::with_meta(
+ mode.tooltip_restart(),
+ Some(&menu::Confirm),
+ "Changes will be discarded",
+ cx,
+ )
+ })
+ .on_click(cx.listener(|_, _, cx| {
+ cx.emit(PromptEditorEvent::StartRequested);
+ }))
+ .into_any_element(),
+ ]
+ } else {
+ let accept = IconButton::new("accept", IconName::Check)
+ .icon_color(Color::Info)
+ .shape(IconButtonShape::Square)
+ .tooltip(move |cx| {
+ Tooltip::for_action(mode.tooltip_accept(), &menu::Confirm, cx)
+ })
+ .on_click(cx.listener(|_, _, cx| {
+ cx.emit(PromptEditorEvent::ConfirmRequested { execute: false });
+ }))
+ .into_any_element();
+
+ match &self.mode {
+ PromptEditorMode::Terminal { .. } => vec![
+ accept,
+ cancel,
+ IconButton::new("confirm", IconName::Play)
+ .icon_color(Color::Info)
+ .shape(IconButtonShape::Square)
+ .tooltip(|cx| {
+ Tooltip::for_action(
+ "Execute Generated Command",
+ &menu::SecondaryConfirm,
+ cx,
+ )
+ })
+ .on_click(cx.listener(|_, _, cx| {
+ cx.emit(PromptEditorEvent::ConfirmRequested { execute: true });
+ }))
+ .into_any_element(),
+ ],
+ PromptEditorMode::Buffer { .. } => vec![accept, cancel],
+ }
+ }
+ }
+ }
+ }
+
+ fn cycle_prev(&mut self, _: &CyclePreviousInlineAssist, cx: &mut ViewContext<Self>) {
+ match &self.mode {
+ PromptEditorMode::Buffer { codegen, .. } => {
+ codegen.update(cx, |codegen, cx| codegen.cycle_prev(cx));
+ }
+ PromptEditorMode::Terminal { .. } => {
+ // no cycle buttons in terminal mode
+ }
+ }
+ }
+
+ fn cycle_next(&mut self, _: &CycleNextInlineAssist, cx: &mut ViewContext<Self>) {
+ match &self.mode {
+ PromptEditorMode::Buffer { codegen, .. } => {
+ codegen.update(cx, |codegen, cx| codegen.cycle_next(cx));
+ }
+ PromptEditorMode::Terminal { .. } => {
+ // no cycle buttons in terminal mode
+ }
+ }
+ }
+
+ fn render_cycle_controls(&self, codegen: &BufferCodegen, cx: &ViewContext<Self>) -> AnyElement {
+ let disabled = matches!(codegen.status(cx), CodegenStatus::Idle);
+
+ let model_registry = LanguageModelRegistry::read_global(cx);
+ let default_model = model_registry.active_model();
+ let alternative_models = model_registry.inline_alternative_models();
+
+ let get_model_name = |index: usize| -> String {
+ let name = |model: &Arc<dyn LanguageModel>| model.name().0.to_string();
+
+ match index {
+ 0 => default_model.as_ref().map_or_else(String::new, name),
+ index if index <= alternative_models.len() => alternative_models
+ .get(index - 1)
+ .map_or_else(String::new, name),
+ _ => String::new(),
+ }
+ };
+
+ let total_models = alternative_models.len() + 1;
+
+ if total_models <= 1 {
+ return div().into_any_element();
+ }
+
+ let current_index = codegen.active_alternative;
+ let prev_index = (current_index + total_models - 1) % total_models;
+ let next_index = (current_index + 1) % total_models;
+
+ let prev_model_name = get_model_name(prev_index);
+ let next_model_name = get_model_name(next_index);
+
+ h_flex()
+ .child(
+ IconButton::new("previous", IconName::ChevronLeft)
+ .icon_color(Color::Muted)
+ .disabled(disabled || current_index == 0)
+ .shape(IconButtonShape::Square)
+ .tooltip({
+ let focus_handle = self.editor.focus_handle(cx);
+ move |cx| {
+ cx.new_view(|cx| {
+ let mut tooltip = Tooltip::new("Previous Alternative").key_binding(
+ KeyBinding::for_action_in(
+ &CyclePreviousInlineAssist,
+ &focus_handle,
+ cx,
+ ),
+ );
+ if !disabled && current_index != 0 {
+ tooltip = tooltip.meta(prev_model_name.clone());
+ }
+ tooltip
+ })
+ .into()
+ }
+ })
+ .on_click(cx.listener(|this, _, cx| {
+ this.cycle_prev(&CyclePreviousInlineAssist, cx);
+ })),
+ )
+ .child(
+ Label::new(format!(
+ "{}/{}",
+ codegen.active_alternative + 1,
+ codegen.alternative_count(cx)
+ ))
+ .size(LabelSize::Small)
+ .color(if disabled {
+ Color::Disabled
+ } else {
+ Color::Muted
+ }),
+ )
+ .child(
+ IconButton::new("next", IconName::ChevronRight)
+ .icon_color(Color::Muted)
+ .disabled(disabled || current_index == total_models - 1)
+ .shape(IconButtonShape::Square)
+ .tooltip({
+ let focus_handle = self.editor.focus_handle(cx);
+ move |cx| {
+ cx.new_view(|cx| {
+ let mut tooltip = Tooltip::new("Next Alternative").key_binding(
+ KeyBinding::for_action_in(
+ &CycleNextInlineAssist,
+ &focus_handle,
+ cx,
+ ),
+ );
+ if !disabled && current_index != total_models - 1 {
+ tooltip = tooltip.meta(next_model_name.clone());
+ }
+ tooltip
+ })
+ .into()
+ }
+ })
+ .on_click(
+ cx.listener(|this, _, cx| this.cycle_next(&CycleNextInlineAssist, cx)),
+ ),
+ )
+ .into_any_element()
+ }
+
+ fn render_rate_limit_notice(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ Popover::new().child(
+ v_flex()
+ .occlude()
+ .p_2()
+ .child(
+ Label::new("Out of Tokens")
+ .size(LabelSize::Small)
+ .weight(FontWeight::BOLD),
+ )
+ .child(Label::new(
+ "Try Zed Pro for higher limits, a wider range of models, and more.",
+ ))
+ .child(
+ h_flex()
+ .justify_between()
+ .child(CheckboxWithLabel::new(
+ "dont-show-again",
+ Label::new("Don't show again"),
+ if dismissed_rate_limit_notice() {
+ ui::ToggleState::Selected
+ } else {
+ ui::ToggleState::Unselected
+ },
+ |selection, cx| {
+ let is_dismissed = match selection {
+ ui::ToggleState::Unselected => false,
+ ui::ToggleState::Indeterminate => return,
+ ui::ToggleState::Selected => true,
+ };
+
+ set_rate_limit_notice_dismissed(is_dismissed, cx)
+ },
+ ))
+ .child(
+ h_flex()
+ .gap_2()
+ .child(
+ Button::new("dismiss", "Dismiss")
+ .style(ButtonStyle::Transparent)
+ .on_click(cx.listener(Self::toggle_rate_limit_notice)),
+ )
+ .child(Button::new("more-info", "More Info").on_click(
+ |_event, cx| {
+ cx.dispatch_action(Box::new(
+ zed_actions::OpenAccountSettings,
+ ))
+ },
+ )),
+ ),
+ ),
+ )
+ }
+
+ fn render_editor(&mut self, cx: &mut ViewContext<Self>) -> AnyElement {
+ let font_size = TextSize::Default.rems(cx);
+ let line_height = font_size.to_pixels(cx.rem_size()) * 1.3;
+
+ v_flex()
+ .key_context("MessageEditor")
+ .size_full()
+ .gap_2()
+ .p_2()
+ .bg(cx.theme().colors().editor_background)
+ .child({
+ let settings = ThemeSettings::get_global(cx);
+ let text_style = TextStyle {
+ color: cx.theme().colors().editor_foreground,
+ font_family: settings.ui_font.family.clone(),
+ font_features: settings.ui_font.features.clone(),
+ font_size: font_size.into(),
+ font_weight: settings.ui_font.weight,
+ line_height: line_height.into(),
+ ..Default::default()
+ };
+
+ EditorElement::new(
+ &self.editor,
+ EditorStyle {
+ background: cx.theme().colors().editor_background,
+ local_player: cx.theme().players().local(),
+ text: text_style,
+ ..Default::default()
+ },
+ )
+ })
+ .into_any_element()
+ }
+}
+
+pub enum PromptEditorMode {
+ Buffer {
+ id: InlineAssistId,
+ codegen: Model<BufferCodegen>,
+ gutter_dimensions: Arc<Mutex<GutterDimensions>>,
+ },
+ Terminal {
+ id: TerminalInlineAssistId,
+ codegen: Model<TerminalCodegen>,
+ height_in_lines: u8,
+ },
+}
+
+pub enum PromptEditorEvent {
+ StartRequested,
+ StopRequested,
+ ConfirmRequested { execute: bool },
+ CancelRequested,
+ DismissRequested,
+ Resized { height_in_lines: u8 },
+}
+
+#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
+pub struct InlineAssistId(pub usize);
+
+impl InlineAssistId {
+ pub fn post_inc(&mut self) -> InlineAssistId {
+ let id = *self;
+ self.0 += 1;
+ id
+ }
+}
+
+impl PromptEditor<BufferCodegen> {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new_buffer(
+ id: InlineAssistId,
+ gutter_dimensions: Arc<Mutex<GutterDimensions>>,
+ prompt_history: VecDeque<String>,
+ prompt_buffer: Model<MultiBuffer>,
+ codegen: Model<BufferCodegen>,
+ fs: Arc<dyn Fs>,
+ context_store: Model<ContextStore>,
+ workspace: WeakView<Workspace>,
+ thread_store: Option<WeakModel<ThreadStore>>,
+ cx: &mut ViewContext<PromptEditor<BufferCodegen>>,
+ ) -> PromptEditor<BufferCodegen> {
+ let codegen_subscription = cx.observe(&codegen, Self::handle_codegen_changed);
+ let mode = PromptEditorMode::Buffer {
+ id,
+ codegen,
+ gutter_dimensions,
+ };
+
+ let prompt_editor = cx.new_view(|cx| {
+ let mut editor = Editor::new(
+ EditorMode::AutoHeight {
+ max_lines: Self::MAX_LINES as usize,
+ },
+ prompt_buffer,
+ None,
+ false,
+ cx,
+ );
+ editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
+ // Since the prompt editors for all inline assistants are linked,
+ // always show the cursor (even when it isn't focused) because
+ // typing in one will make what you typed appear in all of them.
+ editor.set_show_cursor_when_unfocused(true, cx);
+ editor.set_placeholder_text(Self::placeholder_text(&mode, cx), cx);
+ editor
+ });
+ let context_picker_menu_handle = PopoverMenuHandle::default();
+
+ let mut this: PromptEditor<BufferCodegen> = PromptEditor {
+ editor: prompt_editor.clone(),
+ context_strip: cx.new_view(|cx| {
+ ContextStrip::new(
+ context_store,
+ workspace.clone(),
+ thread_store.clone(),
+ prompt_editor.focus_handle(cx),
+ context_picker_menu_handle.clone(),
+ cx,
+ )
+ }),
+ context_picker_menu_handle,
+ language_model_selector: cx.new_view(|cx| {
+ let fs = fs.clone();
+ LanguageModelSelector::new(
+ move |model, cx| {
+ update_settings_file::<AssistantSettings>(
+ fs.clone(),
+ cx,
+ move |settings, _| settings.set_model(model.clone()),
+ );
+ },
+ cx,
+ )
+ }),
+ edited_since_done: false,
+ prompt_history,
+ prompt_history_ix: None,
+ pending_prompt: String::new(),
+ _codegen_subscription: codegen_subscription,
+ editor_subscriptions: Vec::new(),
+ show_rate_limit_notice: false,
+ mode,
+ _phantom: Default::default(),
+ };
+
+ this.subscribe_to_editor(cx);
+ this
+ }
+
+ fn handle_codegen_changed(
+ &mut self,
+ _: Model<BufferCodegen>,
+ cx: &mut ViewContext<PromptEditor<BufferCodegen>>,
+ ) {
+ match self.codegen_status(cx) {
+ CodegenStatus::Idle => {
+ self.editor
+ .update(cx, |editor, _| editor.set_read_only(false));
+ }
+ CodegenStatus::Pending => {
+ self.editor
+ .update(cx, |editor, _| editor.set_read_only(true));
+ }
+ CodegenStatus::Done => {
+ self.edited_since_done = false;
+ self.editor
+ .update(cx, |editor, _| editor.set_read_only(false));
+ }
+ CodegenStatus::Error(error) => {
+ if cx.has_flag::<ZedPro>()
+ && error.error_code() == proto::ErrorCode::RateLimitExceeded
+ && !dismissed_rate_limit_notice()
+ {
+ self.show_rate_limit_notice = true;
+ cx.notify();
+ }
+
+ self.edited_since_done = false;
+ self.editor
+ .update(cx, |editor, _| editor.set_read_only(false));
+ }
+ }
+ }
+
+ pub fn id(&self) -> InlineAssistId {
+ match &self.mode {
+ PromptEditorMode::Buffer { id, .. } => *id,
+ PromptEditorMode::Terminal { .. } => unreachable!(),
+ }
+ }
+
+ pub fn codegen(&self) -> &Model<BufferCodegen> {
+ match &self.mode {
+ PromptEditorMode::Buffer { codegen, .. } => codegen,
+ PromptEditorMode::Terminal { .. } => unreachable!(),
+ }
+ }
+
+ pub fn gutter_dimensions(&self) -> &Arc<Mutex<GutterDimensions>> {
+ match &self.mode {
+ PromptEditorMode::Buffer {
+ gutter_dimensions, ..
+ } => gutter_dimensions,
+ PromptEditorMode::Terminal { .. } => unreachable!(),
+ }
+ }
+}
+
+#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
+pub struct TerminalInlineAssistId(pub usize);
+
+impl TerminalInlineAssistId {
+ pub fn post_inc(&mut self) -> TerminalInlineAssistId {
+ let id = *self;
+ self.0 += 1;
+ id
+ }
+}
+
+impl PromptEditor<TerminalCodegen> {
+ #[allow(clippy::too_many_arguments)]
+ pub fn new_terminal(
+ id: TerminalInlineAssistId,
+ prompt_history: VecDeque<String>,
+ prompt_buffer: Model<MultiBuffer>,
+ codegen: Model<TerminalCodegen>,
+ fs: Arc<dyn Fs>,
+ context_store: Model<ContextStore>,
+ workspace: WeakView<Workspace>,
+ thread_store: Option<WeakModel<ThreadStore>>,
+ cx: &mut ViewContext<Self>,
+ ) -> Self {
+ let codegen_subscription = cx.observe(&codegen, Self::handle_codegen_changed);
+ let mode = PromptEditorMode::Terminal {
+ id,
+ codegen,
+ height_in_lines: 1,
+ };
+
+ let prompt_editor = cx.new_view(|cx| {
+ let mut editor = Editor::new(
+ EditorMode::AutoHeight {
+ max_lines: Self::MAX_LINES as usize,
+ },
+ prompt_buffer,
+ None,
+ false,
+ cx,
+ );
+ editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
+ editor.set_placeholder_text(Self::placeholder_text(&mode, cx), cx);
+ editor
+ });
+ let context_picker_menu_handle = PopoverMenuHandle::default();
+
+ let mut this = Self {
+ editor: prompt_editor.clone(),
+ context_strip: cx.new_view(|cx| {
+ ContextStrip::new(
+ context_store,
+ workspace.clone(),
+ thread_store.clone(),
+ prompt_editor.focus_handle(cx),
+ context_picker_menu_handle.clone(),
+ cx,
+ )
+ }),
+ context_picker_menu_handle,
+ language_model_selector: cx.new_view(|cx| {
+ let fs = fs.clone();
+ LanguageModelSelector::new(
+ move |model, cx| {
+ update_settings_file::<AssistantSettings>(
+ fs.clone(),
+ cx,
+ move |settings, _| settings.set_model(model.clone()),
+ );
+ },
+ cx,
+ )
+ }),
+ edited_since_done: false,
+ prompt_history,
+ prompt_history_ix: None,
+ pending_prompt: String::new(),
+ _codegen_subscription: codegen_subscription,
+ editor_subscriptions: Vec::new(),
+ mode,
+ show_rate_limit_notice: false,
+ _phantom: Default::default(),
+ };
+ this.count_lines(cx);
+ this.subscribe_to_editor(cx);
+ this
+ }
+
+ fn count_lines(&mut self, cx: &mut ViewContext<Self>) {
+ let height_in_lines = cmp::max(
+ 2, // Make the editor at least two lines tall, to account for padding and buttons.
+ cmp::min(
+ self.editor
+ .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1),
+ Self::MAX_LINES as u32,
+ ),
+ ) as u8;
+
+ match &mut self.mode {
+ PromptEditorMode::Terminal {
+ height_in_lines: current_height,
+ ..
+ } => {
+ if height_in_lines != *current_height {
+ *current_height = height_in_lines;
+ cx.emit(PromptEditorEvent::Resized { height_in_lines });
+ }
+ }
+ PromptEditorMode::Buffer { .. } => unreachable!(),
+ }
+ }
+
+ fn handle_codegen_changed(&mut self, _: Model<TerminalCodegen>, cx: &mut ViewContext<Self>) {
+ match &self.codegen().read(cx).status {
+ CodegenStatus::Idle => {
+ self.editor
+ .update(cx, |editor, _| editor.set_read_only(false));
+ }
+ CodegenStatus::Pending => {
+ self.editor
+ .update(cx, |editor, _| editor.set_read_only(true));
+ }
+ CodegenStatus::Done | CodegenStatus::Error(_) => {
+ self.edited_since_done = false;
+ self.editor
+ .update(cx, |editor, _| editor.set_read_only(false));
+ }
+ }
+ }
+
+ pub fn codegen(&self) -> &Model<TerminalCodegen> {
+ match &self.mode {
+ PromptEditorMode::Buffer { .. } => unreachable!(),
+ PromptEditorMode::Terminal { codegen, .. } => codegen,
+ }
+ }
+
+ pub fn id(&self) -> TerminalInlineAssistId {
+ match &self.mode {
+ PromptEditorMode::Buffer { .. } => unreachable!(),
+ PromptEditorMode::Terminal { id, .. } => *id,
+ }
+ }
+}
+
+const DISMISSED_RATE_LIMIT_NOTICE_KEY: &str = "dismissed-rate-limit-notice";
+
+fn dismissed_rate_limit_notice() -> bool {
+ db::kvp::KEY_VALUE_STORE
+ .read_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY)
+ .log_err()
+ .map_or(false, |s| s.is_some())
+}
+
+fn set_rate_limit_notice_dismissed(is_dismissed: bool, cx: &mut AppContext) {
+ db::write_and_log(cx, move || async move {
+ if is_dismissed {
+ db::kvp::KEY_VALUE_STORE
+ .write_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into(), "1".into())
+ .await
+ } else {
+ db::kvp::KEY_VALUE_STORE
+ .delete_kvp(DISMISSED_RATE_LIMIT_NOTICE_KEY.into())
+ .await
+ }
+ })
+}
pub enum CodegenStatus {
Idle,
@@ -0,0 +1,192 @@
+use crate::inline_prompt_editor::CodegenStatus;
+use client::telemetry::Telemetry;
+use futures::{channel::mpsc, SinkExt, StreamExt};
+use gpui::{AppContext, EventEmitter, Model, ModelContext, Task};
+use language_model::{LanguageModelRegistry, LanguageModelRequest};
+use language_models::report_assistant_event;
+use std::{sync::Arc, time::Instant};
+use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
+use terminal::Terminal;
+
+pub struct TerminalCodegen {
+ pub status: CodegenStatus,
+ pub telemetry: Option<Arc<Telemetry>>,
+ terminal: Model<Terminal>,
+ generation: Task<()>,
+ pub message_id: Option<String>,
+ transaction: Option<TerminalTransaction>,
+}
+
+impl EventEmitter<CodegenEvent> for TerminalCodegen {}
+
+impl TerminalCodegen {
+ pub fn new(terminal: Model<Terminal>, telemetry: Option<Arc<Telemetry>>) -> Self {
+ Self {
+ terminal,
+ telemetry,
+ status: CodegenStatus::Idle,
+ generation: Task::ready(()),
+ message_id: None,
+ transaction: None,
+ }
+ }
+
+ pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
+ let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
+ return;
+ };
+
+ let model_api_key = model.api_key(cx);
+ let http_client = cx.http_client();
+ let telemetry = self.telemetry.clone();
+ self.status = CodegenStatus::Pending;
+ self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
+ self.generation = cx.spawn(|this, mut cx| async move {
+ let model_telemetry_id = model.telemetry_id();
+ let model_provider_id = model.provider_id();
+ let response = model.stream_completion_text(prompt, &cx).await;
+ let generate = async {
+ let message_id = response
+ .as_ref()
+ .ok()
+ .and_then(|response| response.message_id.clone());
+
+ let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
+
+ let task = cx.background_executor().spawn({
+ let message_id = message_id.clone();
+ let executor = cx.background_executor().clone();
+ async move {
+ let mut response_latency = None;
+ let request_start = Instant::now();
+ let task = async {
+ let mut chunks = response?.stream;
+ while let Some(chunk) = chunks.next().await {
+ if response_latency.is_none() {
+ response_latency = Some(request_start.elapsed());
+ }
+ let chunk = chunk?;
+ hunks_tx.send(chunk).await?;
+ }
+
+ anyhow::Ok(())
+ };
+
+ let result = task.await;
+
+ let error_message = result.as_ref().err().map(|error| error.to_string());
+ report_assistant_event(
+ AssistantEvent {
+ conversation_id: None,
+ kind: AssistantKind::InlineTerminal,
+ message_id,
+ phase: AssistantPhase::Response,
+ model: model_telemetry_id,
+ model_provider: model_provider_id.to_string(),
+ response_latency,
+ error_message,
+ language_name: None,
+ },
+ telemetry,
+ http_client,
+ model_api_key,
+ &executor,
+ );
+
+ result?;
+ anyhow::Ok(())
+ }
+ });
+
+ this.update(&mut cx, |this, _| {
+ this.message_id = message_id;
+ })?;
+
+ while let Some(hunk) = hunks_rx.next().await {
+ this.update(&mut cx, |this, cx| {
+ if let Some(transaction) = &mut this.transaction {
+ transaction.push(hunk, cx);
+ cx.notify();
+ }
+ })?;
+ }
+
+ task.await?;
+ anyhow::Ok(())
+ };
+
+ let result = generate.await;
+
+ this.update(&mut cx, |this, cx| {
+ if let Err(error) = result {
+ this.status = CodegenStatus::Error(error);
+ } else {
+ this.status = CodegenStatus::Done;
+ }
+ cx.emit(CodegenEvent::Finished);
+ cx.notify();
+ })
+ .ok();
+ });
+ cx.notify();
+ }
+
+ pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
+ self.status = CodegenStatus::Done;
+ self.generation = Task::ready(());
+ cx.emit(CodegenEvent::Finished);
+ cx.notify();
+ }
+
+ pub fn complete(&mut self, cx: &mut ModelContext<Self>) {
+ if let Some(transaction) = self.transaction.take() {
+ transaction.complete(cx);
+ }
+ }
+
+ pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
+ if let Some(transaction) = self.transaction.take() {
+ transaction.undo(cx);
+ }
+ }
+}
+
+#[derive(Copy, Clone, Debug)]
+pub enum CodegenEvent {
+ Finished,
+}
+
+pub const CLEAR_INPUT: &str = "\x15";
+const CARRIAGE_RETURN: &str = "\x0d";
+
+struct TerminalTransaction {
+ terminal: Model<Terminal>,
+}
+
+impl TerminalTransaction {
+ pub fn start(terminal: Model<Terminal>) -> Self {
+ Self { terminal }
+ }
+
+ pub fn push(&mut self, hunk: String, cx: &mut AppContext) {
+ // Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
+ let input = Self::sanitize_input(hunk);
+ self.terminal
+ .update(cx, |terminal, _| terminal.input(input));
+ }
+
+ pub fn undo(&self, cx: &mut AppContext) {
+ self.terminal
+ .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
+ }
+
+ pub fn complete(&self, cx: &mut AppContext) {
+ self.terminal.update(cx, |terminal, _| {
+ terminal.input(CARRIAGE_RETURN.to_string())
+ });
+ }
+
+ fn sanitize_input(input: String) -> String {
+ input.replace(['\r', '\n'], "")
+ }
+}
@@ -1,38 +1,29 @@
use crate::context::attach_context_to_message;
-use crate::context_picker::ContextPicker;
use crate::context_store::ContextStore;
-use crate::context_strip::ContextStrip;
-use crate::inline_prompt_editor::{CodegenStatus, PromptEditorEvent, PromptMode};
+use crate::inline_prompt_editor::{
+ CodegenStatus, PromptEditor, PromptEditorEvent, TerminalInlineAssistId,
+};
use crate::prompts::PromptBuilder;
+use crate::terminal_codegen::{CodegenEvent, TerminalCodegen, CLEAR_INPUT};
use crate::thread_store::ThreadStore;
-use crate::ToggleContextPicker;
-use crate::{assistant_settings::AssistantSettings, inline_prompt_editor::render_cancel_button};
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
use collections::{HashMap, VecDeque};
-use editor::{
- actions::{MoveDown, MoveUp, SelectAll},
- Editor, EditorElement, EditorEvent, EditorMode, EditorStyle, MultiBuffer,
-};
+use editor::{actions::SelectAll, MultiBuffer};
use fs::Fs;
-use futures::{channel::mpsc, SinkExt, StreamExt};
use gpui::{
- AppContext, Context, EventEmitter, FocusHandle, FocusableView, Global, Model, ModelContext,
- Subscription, Task, TextStyle, UpdateGlobal, View, WeakModel, WeakView,
+ AppContext, Context, FocusableView, Global, Model, Subscription, UpdateGlobal, View, WeakModel,
+ WeakView,
};
use language::Buffer;
use language_model::{
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
-use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
use language_models::report_assistant_event;
-use settings::{update_settings_file, Settings};
-use std::{cmp, sync::Arc, time::Instant};
+use std::sync::Arc;
use telemetry_events::{AssistantEvent, AssistantKind, AssistantPhase};
-use terminal::Terminal;
use terminal_view::TerminalView;
-use theme::ThemeSettings;
-use ui::{prelude::*, text_for_action, IconButtonShape, PopoverMenuHandle, Tooltip};
+use ui::prelude::*;
use util::ResultExt;
use workspace::{notifications::NotificationId, Toast, Workspace};
@@ -48,17 +39,6 @@ pub fn init(
const DEFAULT_CONTEXT_LINES: usize = 50;
const PROMPT_HISTORY_MAX_LEN: usize = 20;
-#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
-struct TerminalInlineAssistId(usize);
-
-impl TerminalInlineAssistId {
- fn post_inc(&mut self) -> TerminalInlineAssistId {
- let id = *self;
- self.0 += 1;
- id
- }
-}
-
pub struct TerminalInlineAssistant {
next_assist_id: TerminalInlineAssistId,
assists: HashMap<TerminalInlineAssistId, TerminalInlineAssist>,
@@ -99,10 +79,10 @@ impl TerminalInlineAssistant {
MultiBuffer::singleton(cx.new_model(|cx| Buffer::local(String::new(), cx)), cx)
});
let context_store = cx.new_model(|_cx| ContextStore::new());
- let codegen = cx.new_model(|_| Codegen::new(terminal, self.telemetry.clone()));
+ let codegen = cx.new_model(|_| TerminalCodegen::new(terminal, self.telemetry.clone()));
let prompt_editor = cx.new_view(|cx| {
- PromptEditor::new(
+ PromptEditor::new_terminal(
assist_id,
self.prompt_history.clone(),
prompt_buffer.clone(),
@@ -151,11 +131,11 @@ impl TerminalInlineAssistant {
fn handle_prompt_editor_event(
&mut self,
- prompt_editor: View<PromptEditor>,
+ prompt_editor: View<PromptEditor<TerminalCodegen>>,
event: &PromptEditorEvent,
cx: &mut WindowContext,
) {
- let assist_id = prompt_editor.read(cx).id;
+ let assist_id = prompt_editor.read(cx).id();
match event {
PromptEditorEvent::StartRequested => {
self.start_assist(assist_id, cx);
@@ -381,8 +361,8 @@ impl TerminalInlineAssistant {
struct TerminalInlineAssist {
terminal: WeakView<TerminalView>,
- prompt_editor: Option<View<PromptEditor>>,
- codegen: Model<Codegen>,
+ prompt_editor: Option<View<PromptEditor<TerminalCodegen>>>,
+ codegen: Model<TerminalCodegen>,
workspace: WeakView<Workspace>,
context_store: Model<ContextStore>,
_subscriptions: Vec<Subscription>,
@@ -392,12 +372,12 @@ impl TerminalInlineAssist {
pub fn new(
assist_id: TerminalInlineAssistId,
terminal: &View<TerminalView>,
- prompt_editor: View<PromptEditor>,
+ prompt_editor: View<PromptEditor<TerminalCodegen>>,
workspace: WeakView<Workspace>,
context_store: Model<ContextStore>,
cx: &mut WindowContext,
) -> Self {
- let codegen = prompt_editor.read(cx).codegen.clone();
+ let codegen = prompt_editor.read(cx).codegen().clone();
Self {
terminal: terminal.downgrade(),
prompt_editor: Some(prompt_editor.clone()),
@@ -448,556 +428,3 @@ impl TerminalInlineAssist {
}
}
}
-
-struct PromptEditor {
- id: TerminalInlineAssistId,
- height_in_lines: u8,
- editor: View<Editor>,
- context_strip: View<ContextStrip>,
- context_picker_menu_handle: PopoverMenuHandle<ContextPicker>,
- language_model_selector: View<LanguageModelSelector>,
- edited_since_done: bool,
- prompt_history: VecDeque<String>,
- prompt_history_ix: Option<usize>,
- pending_prompt: String,
- codegen: Model<Codegen>,
- _codegen_subscription: Subscription,
- editor_subscriptions: Vec<Subscription>,
-}
-
-impl EventEmitter<PromptEditorEvent> for PromptEditor {}
-
-impl Render for PromptEditor {
- fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
- let mut buttons = Vec::new();
-
- buttons.extend(render_cancel_button(
- (&self.codegen.read(cx).status).into(),
- self.edited_since_done,
- PromptMode::Generate {
- supports_execute: true,
- },
- cx,
- ));
-
- v_flex()
- .border_y_1()
- .border_color(cx.theme().status().info_border)
- .py_2()
- .size_full()
- .child(
- h_flex()
- .key_context("PromptEditor")
- .bg(cx.theme().colors().editor_background)
- .on_action(cx.listener(Self::toggle_context_picker))
- .on_action(cx.listener(Self::confirm))
- .on_action(cx.listener(Self::secondary_confirm))
- .on_action(cx.listener(Self::cancel))
- .on_action(cx.listener(Self::move_up))
- .on_action(cx.listener(Self::move_down))
- .child(
- h_flex()
- .w_12()
- .justify_center()
- .gap_2()
- .child(LanguageModelSelectorPopoverMenu::new(
- self.language_model_selector.clone(),
- IconButton::new("context", IconName::SettingsAlt)
- .shape(IconButtonShape::Square)
- .icon_size(IconSize::Small)
- .icon_color(Color::Muted)
- .tooltip(move |cx| {
- Tooltip::with_meta(
- format!(
- "Using {}",
- LanguageModelRegistry::read_global(cx)
- .active_model()
- .map(|model| model.name().0)
- .unwrap_or_else(|| "No model selected".into()),
- ),
- None,
- "Change Model",
- cx,
- )
- }),
- ))
- .children(
- if let CodegenStatus::Error(error) = &self.codegen.read(cx).status {
- let error_message = SharedString::from(error.to_string());
- Some(
- div()
- .id("error")
- .tooltip(move |cx| {
- Tooltip::text(error_message.clone(), cx)
- })
- .child(
- Icon::new(IconName::XCircle)
- .size(IconSize::Small)
- .color(Color::Error),
- ),
- )
- } else {
- None
- },
- ),
- )
- .child(div().flex_1().child(self.render_prompt_editor(cx)))
- .child(h_flex().gap_1().pr_4().children(buttons)),
- )
- .child(h_flex().child(self.context_strip.clone()))
- }
-}
-
-impl FocusableView for PromptEditor {
- fn focus_handle(&self, cx: &AppContext) -> FocusHandle {
- self.editor.focus_handle(cx)
- }
-}
-
-impl PromptEditor {
- const MAX_LINES: u8 = 8;
-
- #[allow(clippy::too_many_arguments)]
- fn new(
- id: TerminalInlineAssistId,
- prompt_history: VecDeque<String>,
- prompt_buffer: Model<MultiBuffer>,
- codegen: Model<Codegen>,
- fs: Arc<dyn Fs>,
- context_store: Model<ContextStore>,
- workspace: WeakView<Workspace>,
- thread_store: Option<WeakModel<ThreadStore>>,
- cx: &mut ViewContext<Self>,
- ) -> Self {
- let prompt_editor = cx.new_view(|cx| {
- let mut editor = Editor::new(
- EditorMode::AutoHeight {
- max_lines: Self::MAX_LINES as usize,
- },
- prompt_buffer,
- None,
- false,
- cx,
- );
- editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
- editor.set_placeholder_text(Self::placeholder_text(cx), cx);
- editor
- });
- let context_picker_menu_handle = PopoverMenuHandle::default();
-
- let mut this = Self {
- id,
- height_in_lines: 1,
- editor: prompt_editor.clone(),
- context_strip: cx.new_view(|cx| {
- ContextStrip::new(
- context_store,
- workspace.clone(),
- thread_store.clone(),
- prompt_editor.focus_handle(cx),
- context_picker_menu_handle.clone(),
- cx,
- )
- }),
- context_picker_menu_handle,
- language_model_selector: cx.new_view(|cx| {
- let fs = fs.clone();
- LanguageModelSelector::new(
- move |model, cx| {
- update_settings_file::<AssistantSettings>(
- fs.clone(),
- cx,
- move |settings, _| settings.set_model(model.clone()),
- );
- },
- cx,
- )
- }),
- edited_since_done: false,
- prompt_history,
- prompt_history_ix: None,
- pending_prompt: String::new(),
- _codegen_subscription: cx.observe(&codegen, Self::handle_codegen_changed),
- editor_subscriptions: Vec::new(),
- codegen,
- };
- this.count_lines(cx);
- this.subscribe_to_editor(cx);
- this
- }
-
- fn placeholder_text(cx: &WindowContext) -> String {
- let context_keybinding = text_for_action(&crate::ToggleFocus, cx)
- .map(|keybinding| format!(" β’ {keybinding} for context"))
- .unwrap_or_default();
-
- format!("Generateβ¦{context_keybinding} ββ for history")
- }
-
- fn subscribe_to_editor(&mut self, cx: &mut ViewContext<Self>) {
- self.editor_subscriptions.clear();
- self.editor_subscriptions
- .push(cx.observe(&self.editor, Self::handle_prompt_editor_changed));
- self.editor_subscriptions
- .push(cx.subscribe(&self.editor, Self::handle_prompt_editor_events));
- }
-
- fn prompt(&self, cx: &AppContext) -> String {
- self.editor.read(cx).text(cx)
- }
-
- fn count_lines(&mut self, cx: &mut ViewContext<Self>) {
- let height_in_lines = cmp::max(
- 2, // Make the editor at least two lines tall, to account for padding and buttons.
- cmp::min(
- self.editor
- .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1),
- Self::MAX_LINES as u32,
- ),
- ) as u8;
-
- if height_in_lines != self.height_in_lines {
- self.height_in_lines = height_in_lines;
- cx.emit(PromptEditorEvent::Resized { height_in_lines });
- }
- }
-
- fn handle_prompt_editor_changed(&mut self, _: View<Editor>, cx: &mut ViewContext<Self>) {
- self.count_lines(cx);
- }
-
- fn handle_prompt_editor_events(
- &mut self,
- _: View<Editor>,
- event: &EditorEvent,
- cx: &mut ViewContext<Self>,
- ) {
- match event {
- EditorEvent::Edited { .. } => {
- let prompt = self.editor.read(cx).text(cx);
- if self
- .prompt_history_ix
- .map_or(true, |ix| self.prompt_history[ix] != prompt)
- {
- self.prompt_history_ix.take();
- self.pending_prompt = prompt;
- }
-
- self.edited_since_done = true;
- cx.notify();
- }
- _ => {}
- }
- }
-
- fn handle_codegen_changed(&mut self, _: Model<Codegen>, cx: &mut ViewContext<Self>) {
- match &self.codegen.read(cx).status {
- CodegenStatus::Idle => {
- self.editor
- .update(cx, |editor, _| editor.set_read_only(false));
- }
- CodegenStatus::Pending => {
- self.editor
- .update(cx, |editor, _| editor.set_read_only(true));
- }
- CodegenStatus::Done | CodegenStatus::Error(_) => {
- self.edited_since_done = false;
- self.editor
- .update(cx, |editor, _| editor.set_read_only(false));
- }
- }
- }
-
- fn toggle_context_picker(&mut self, _: &ToggleContextPicker, cx: &mut ViewContext<Self>) {
- self.context_picker_menu_handle.toggle(cx);
- }
-
- fn cancel(&mut self, _: &editor::actions::Cancel, cx: &mut ViewContext<Self>) {
- match &self.codegen.read(cx).status {
- CodegenStatus::Idle | CodegenStatus::Done | CodegenStatus::Error(_) => {
- cx.emit(PromptEditorEvent::CancelRequested);
- }
- CodegenStatus::Pending => {
- cx.emit(PromptEditorEvent::StopRequested);
- }
- }
- }
-
- fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext<Self>) {
- match &self.codegen.read(cx).status {
- CodegenStatus::Idle => {
- if !self.editor.read(cx).text(cx).trim().is_empty() {
- cx.emit(PromptEditorEvent::StartRequested);
- }
- }
- CodegenStatus::Pending => {
- cx.emit(PromptEditorEvent::DismissRequested);
- }
- CodegenStatus::Done => {
- if self.edited_since_done {
- cx.emit(PromptEditorEvent::StartRequested);
- } else {
- cx.emit(PromptEditorEvent::ConfirmRequested { execute: false });
- }
- }
- CodegenStatus::Error(_) => {
- cx.emit(PromptEditorEvent::StartRequested);
- }
- }
- }
-
- fn secondary_confirm(&mut self, _: &menu::SecondaryConfirm, cx: &mut ViewContext<Self>) {
- if matches!(self.codegen.read(cx).status, CodegenStatus::Done) {
- cx.emit(PromptEditorEvent::ConfirmRequested { execute: true });
- }
- }
-
- fn move_up(&mut self, _: &MoveUp, cx: &mut ViewContext<Self>) {
- if let Some(ix) = self.prompt_history_ix {
- if ix > 0 {
- self.prompt_history_ix = Some(ix - 1);
- let prompt = self.prompt_history[ix - 1].as_str();
- self.editor.update(cx, |editor, cx| {
- editor.set_text(prompt, cx);
- editor.move_to_beginning(&Default::default(), cx);
- });
- }
- } else if !self.prompt_history.is_empty() {
- self.prompt_history_ix = Some(self.prompt_history.len() - 1);
- let prompt = self.prompt_history[self.prompt_history.len() - 1].as_str();
- self.editor.update(cx, |editor, cx| {
- editor.set_text(prompt, cx);
- editor.move_to_beginning(&Default::default(), cx);
- });
- }
- }
-
- fn move_down(&mut self, _: &MoveDown, cx: &mut ViewContext<Self>) {
- if let Some(ix) = self.prompt_history_ix {
- if ix < self.prompt_history.len() - 1 {
- self.prompt_history_ix = Some(ix + 1);
- let prompt = self.prompt_history[ix + 1].as_str();
- self.editor.update(cx, |editor, cx| {
- editor.set_text(prompt, cx);
- editor.move_to_end(&Default::default(), cx)
- });
- } else {
- self.prompt_history_ix = None;
- let prompt = self.pending_prompt.as_str();
- self.editor.update(cx, |editor, cx| {
- editor.set_text(prompt, cx);
- editor.move_to_end(&Default::default(), cx)
- });
- }
- }
- }
-
- fn render_prompt_editor(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
- let settings = ThemeSettings::get_global(cx);
- let text_style = TextStyle {
- color: if self.editor.read(cx).read_only(cx) {
- cx.theme().colors().text_disabled
- } else {
- cx.theme().colors().text
- },
- font_family: settings.buffer_font.family.clone(),
- font_fallbacks: settings.buffer_font.fallbacks.clone(),
- font_size: settings.buffer_font_size.into(),
- font_weight: settings.buffer_font.weight,
- line_height: relative(settings.buffer_line_height.value()),
- ..Default::default()
- };
- EditorElement::new(
- &self.editor,
- EditorStyle {
- background: cx.theme().colors().editor_background,
- local_player: cx.theme().players().local(),
- text: text_style,
- ..Default::default()
- },
- )
- }
-}
-
-#[derive(Debug)]
-pub enum CodegenEvent {
- Finished,
-}
-
-impl EventEmitter<CodegenEvent> for Codegen {}
-
-const CLEAR_INPUT: &str = "\x15";
-const CARRIAGE_RETURN: &str = "\x0d";
-
-struct TerminalTransaction {
- terminal: Model<Terminal>,
-}
-
-impl TerminalTransaction {
- pub fn start(terminal: Model<Terminal>) -> Self {
- Self { terminal }
- }
-
- pub fn push(&mut self, hunk: String, cx: &mut AppContext) {
- // Ensure that the assistant cannot accidentally execute commands that are streamed into the terminal
- let input = Self::sanitize_input(hunk);
- self.terminal
- .update(cx, |terminal, _| terminal.input(input));
- }
-
- pub fn undo(&self, cx: &mut AppContext) {
- self.terminal
- .update(cx, |terminal, _| terminal.input(CLEAR_INPUT.to_string()));
- }
-
- pub fn complete(&self, cx: &mut AppContext) {
- self.terminal.update(cx, |terminal, _| {
- terminal.input(CARRIAGE_RETURN.to_string())
- });
- }
-
- fn sanitize_input(input: String) -> String {
- input.replace(['\r', '\n'], "")
- }
-}
-
-pub struct Codegen {
- status: CodegenStatus,
- telemetry: Option<Arc<Telemetry>>,
- terminal: Model<Terminal>,
- generation: Task<()>,
- message_id: Option<String>,
- transaction: Option<TerminalTransaction>,
-}
-
-impl Codegen {
- pub fn new(terminal: Model<Terminal>, telemetry: Option<Arc<Telemetry>>) -> Self {
- Self {
- terminal,
- telemetry,
- status: CodegenStatus::Idle,
- generation: Task::ready(()),
- message_id: None,
- transaction: None,
- }
- }
-
- pub fn start(&mut self, prompt: LanguageModelRequest, cx: &mut ModelContext<Self>) {
- let Some(model) = LanguageModelRegistry::read_global(cx).active_model() else {
- return;
- };
-
- let model_api_key = model.api_key(cx);
- let http_client = cx.http_client();
- let telemetry = self.telemetry.clone();
- self.status = CodegenStatus::Pending;
- self.transaction = Some(TerminalTransaction::start(self.terminal.clone()));
- self.generation = cx.spawn(|this, mut cx| async move {
- let model_telemetry_id = model.telemetry_id();
- let model_provider_id = model.provider_id();
- let response = model.stream_completion_text(prompt, &cx).await;
- let generate = async {
- let message_id = response
- .as_ref()
- .ok()
- .and_then(|response| response.message_id.clone());
-
- let (mut hunks_tx, mut hunks_rx) = mpsc::channel(1);
-
- let task = cx.background_executor().spawn({
- let message_id = message_id.clone();
- let executor = cx.background_executor().clone();
- async move {
- let mut response_latency = None;
- let request_start = Instant::now();
- let task = async {
- let mut chunks = response?.stream;
- while let Some(chunk) = chunks.next().await {
- if response_latency.is_none() {
- response_latency = Some(request_start.elapsed());
- }
- let chunk = chunk?;
- hunks_tx.send(chunk).await?;
- }
-
- anyhow::Ok(())
- };
-
- let result = task.await;
-
- let error_message = result.as_ref().err().map(|error| error.to_string());
- report_assistant_event(
- AssistantEvent {
- conversation_id: None,
- kind: AssistantKind::InlineTerminal,
- message_id,
- phase: AssistantPhase::Response,
- model: model_telemetry_id,
- model_provider: model_provider_id.to_string(),
- response_latency,
- error_message,
- language_name: None,
- },
- telemetry,
- http_client,
- model_api_key,
- &executor,
- );
-
- result?;
- anyhow::Ok(())
- }
- });
-
- this.update(&mut cx, |this, _| {
- this.message_id = message_id;
- })?;
-
- while let Some(hunk) = hunks_rx.next().await {
- this.update(&mut cx, |this, cx| {
- if let Some(transaction) = &mut this.transaction {
- transaction.push(hunk, cx);
- cx.notify();
- }
- })?;
- }
-
- task.await?;
- anyhow::Ok(())
- };
-
- let result = generate.await;
-
- this.update(&mut cx, |this, cx| {
- if let Err(error) = result {
- this.status = CodegenStatus::Error(error);
- } else {
- this.status = CodegenStatus::Done;
- }
- cx.emit(CodegenEvent::Finished);
- cx.notify();
- })
- .ok();
- });
- cx.notify();
- }
-
- pub fn stop(&mut self, cx: &mut ModelContext<Self>) {
- self.status = CodegenStatus::Done;
- self.generation = Task::ready(());
- cx.emit(CodegenEvent::Finished);
- cx.notify();
- }
-
- pub fn complete(&mut self, cx: &mut ModelContext<Self>) {
- if let Some(transaction) = self.transaction.take() {
- transaction.complete(cx);
- }
- }
-
- pub fn undo(&mut self, cx: &mut ModelContext<Self>) {
- if let Some(transaction) = self.transaction.take() {
- transaction.undo(cx);
- }
- }
-}