@@ -543,7 +543,7 @@ impl Assistant {
fn count_remaining_tokens(&mut self, cx: &mut ModelContext<Self>) {
let messages = self
- .open_ai_request_messages(cx)
+ .messages(cx)
.into_iter()
.filter_map(|message| {
Some(tiktoken_rs::ChatCompletionRequestMessage {
@@ -552,7 +552,7 @@ impl Assistant {
Role::Assistant => "assistant".into(),
Role::System => "system".into(),
},
- content: message.content,
+ content: self.buffer.read(cx).text_for_range(message.range).collect(),
name: None,
})
})
@@ -596,7 +596,10 @@ impl Assistant {
) -> Option<(MessageAnchor, MessageAnchor)> {
let request = OpenAIRequest {
model: self.model.clone(),
- messages: self.open_ai_request_messages(cx),
+ messages: self
+ .messages(cx)
+ .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
+ .collect(),
stream: true,
};
@@ -841,16 +844,19 @@ impl Assistant {
if self.message_anchors.len() >= 2 && self.summary.is_none() {
let api_key = self.api_key.borrow().clone();
if let Some(api_key) = api_key {
- let mut messages = self.open_ai_request_messages(cx);
- messages.truncate(2);
- messages.push(RequestMessage {
- role: Role::User,
- content: "Summarize the conversation into a short title without punctuation"
- .into(),
- });
+ let messages = self
+ .messages(cx)
+ .take(2)
+ .map(|message| message.to_open_ai_message(self.buffer.read(cx)))
+ .chain(Some(RequestMessage {
+ role: Role::User,
+ content:
+ "Summarize the conversation into a short title without punctuation"
+ .into(),
+ }));
let request = OpenAIRequest {
model: self.model.clone(),
- messages,
+ messages: messages.collect(),
stream: true,
};
@@ -878,16 +884,6 @@ impl Assistant {
}
}
- fn open_ai_request_messages(&self, cx: &AppContext) -> Vec<RequestMessage> {
- let buffer = self.buffer.read(cx);
- self.messages(cx)
- .map(|message| RequestMessage {
- role: message.role,
- content: buffer.text_for_range(message.range).collect(),
- })
- .collect()
- }
-
fn message_for_offset<'a>(&'a self, offset: usize, cx: &'a AppContext) -> Option<Message> {
let mut messages = self.messages(cx).peekable();
while let Some(message) = messages.next() {
@@ -1446,6 +1442,15 @@ pub struct Message {
error: Option<Arc<str>>,
}
+impl Message {
+ fn to_open_ai_message(&self, buffer: &Buffer) -> RequestMessage {
+ RequestMessage {
+ role: self.role,
+ content: buffer.text_for_range(self.range.clone()).collect(),
+ }
+ }
+}
+
async fn stream_completion(
api_key: String,
executor: Arc<Background>,