From 9822d9673c00b914167a85ef0d7756a331af70f2 Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Wed, 26 Feb 2025 19:16:44 -0500 Subject: [PATCH] assistant2: Add `Thread::send_to_model` method (#25703) This PR adds a new `send_to_model` method to the `Thread` to encapsulate more of the thread-specific capabilities. We then call this in `MessageEditor::send_to_model`. Release Notes: - N/A --- crates/assistant2/src/message_editor.rs | 19 ++------------- crates/assistant2/src/thread.rs | 31 ++++++++++++++++++++++--- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index 811b4b7728a06d22bb64bfa87b387e9c0a3e730e..2c749ba7f046fffaf35eef2b0df1b9f565e6d0f7 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -7,7 +7,7 @@ use gpui::{ pulsating_between, Animation, AnimationExt, App, DismissEvent, Entity, Focusable, Subscription, TextStyle, WeakEntity, }; -use language_model::{LanguageModelRegistry, LanguageModelRequestTool}; +use language_model::LanguageModelRegistry; use language_model_selector::LanguageModelSelector; use rope::Point; use settings::Settings; @@ -205,22 +205,7 @@ impl MessageEditor { .update(&mut cx, |thread, cx| { let context = context_store.read(cx).snapshot(cx).collect::>(); thread.insert_user_message(user_message, context, cx); - let mut request = thread.to_completion_request(request_kind, cx); - - if use_tools { - request.tools = thread - .tools() - .tools(cx) - .into_iter() - .map(|tool| LanguageModelRequestTool { - name: tool.name(), - description: tool.description(), - input_schema: tool.input_schema(), - }) - .collect(); - } - - thread.stream_completion(request, model, cx) + thread.send_to_model(model, request_kind, use_tools, cx); }) .ok(); }) diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 47c1222807033abae2904462de796d6dedcb2ec4..2b519194540fcfe5ff48911cd37b01fee2de10db 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -9,9 +9,9 @@ use futures::{FutureExt as _, StreamExt as _}; use gpui::{App, Context, EventEmitter, SharedString, Task}; use language_model::{ LanguageModel, LanguageModelCompletionEvent, LanguageModelRegistry, LanguageModelRequest, - LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolUse, - LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, PaymentRequiredError, - Role, StopReason, + LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult, + LanguageModelToolUse, LanguageModelToolUseId, MaxMonthlySpendReachedError, MessageContent, + PaymentRequiredError, Role, StopReason, }; use serde::{Deserialize, Serialize}; use util::{post_inc, TryFutureExt as _}; @@ -314,6 +314,31 @@ impl Thread { text } + pub fn send_to_model( + &mut self, + model: Arc, + request_kind: RequestKind, + use_tools: bool, + cx: &mut Context, + ) { + let mut request = self.to_completion_request(request_kind, cx); + + if use_tools { + request.tools = self + .tools() + .tools(cx) + .into_iter() + .map(|tool| LanguageModelRequestTool { + name: tool.name(), + description: tool.description(), + input_schema: tool.input_schema(), + }) + .collect(); + } + + self.stream_completion(request, model, cx); + } + pub fn to_completion_request( &self, request_kind: RequestKind,