@@ -919,40 +919,10 @@ impl Zeta {
.ok();
}
- let (mut res, usage) = response?;
-
+ let (res, usage) = response?;
let request_id = EditPredictionId(Uuid::from_str(&res.id)?);
-
- let Some(choice) = res.choices.pop() else {
- return Ok((None, usage));
- };
-
- let output_text = match choice.message {
- open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Plain(content)),
- ..
- } => content,
- open_ai::RequestMessage::Assistant {
- content: Some(open_ai::MessageContent::Multipart(mut content)),
- ..
- } => {
- if content.is_empty() {
- log::error!("No output from Baseten completion response");
- return Ok((None, usage));
- }
-
- match content.remove(0) {
- open_ai::MessagePart::Text { text } => text,
- open_ai::MessagePart::Image { .. } => {
- log::error!("Expected text, got an image");
- return Ok((None, usage));
- }
- }
- }
- _ => {
- log::error!("Invalid response message: {:?}", choice.message);
- return Ok((None, usage));
- }
+ let Some(output_text) = text_from_response(res) else {
+ return Ok((None, usage))
};
let (edited_buffer_snapshot, edits) =
@@ -1216,6 +1186,10 @@ impl Zeta {
return Task::ready(Ok(()));
};
+ let app_version = AppVersion::global(cx);
+ let client = self.client.clone();
+ let llm_token = self.llm_token.clone();
+ let debug_tx = self.debug_tx.clone();
let current_file_path: Arc<Path> = snapshot
.file()
.map(|f| f.full_path(cx).into())
@@ -1240,10 +1214,17 @@ impl Zeta {
}
};
- let app_version = AppVersion::global(cx);
- let client = self.client.clone();
- let llm_token = self.llm_token.clone();
- let debug_tx = self.debug_tx.clone();
+ if let Some(debug_tx) = &debug_tx {
+ debug_tx
+ .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
+ ZetaContextRetrievalStartedDebugInfo {
+ project: project.clone(),
+ timestamp: Instant::now(),
+ search_prompt: prompt.clone(),
+ },
+ ))
+ .ok();
+ }
let (tool_schema, tool_description) = &*cloud_zeta2_prompt::retrieval_prompt::TOOL_SCHEMA;
@@ -1503,6 +1484,38 @@ impl Zeta {
}
}
+pub fn text_from_response(mut res: open_ai::Response) -> Option<String> {
+ let choice = res.choices.pop()?;
+ let output_text = match choice.message {
+ open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Plain(content)),
+ ..
+ } => content,
+ open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Multipart(mut content)),
+ ..
+ } => {
+ if content.is_empty() {
+ log::error!("No output from Baseten completion response");
+ return None;
+ }
+
+ match content.remove(0) {
+ open_ai::MessagePart::Text { text } => text,
+ open_ai::MessagePart::Image { .. } => {
+ log::error!("Expected text, got an image");
+ return None;
+ }
+ }
+ }
+ _ => {
+ log::error!("Invalid response message: {:?}", choice.message);
+ return None;
+ }
+ };
+ Some(output_text)
+}
+
#[derive(Error, Debug)]
#[error(
"You must update to Zed version {minimum_version} or higher to continue using edit predictions."
@@ -166,9 +166,7 @@ impl NamedExample {
REVISION_FIELD => {
named.example.revision = value.trim().to_string();
}
- _ => {
- eprintln!("Warning: Unrecognized field `{field}`");
- }
+ _ => {}
}
}
}
@@ -193,7 +191,6 @@ impl NamedExample {
} else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
Section::ExpectedExcerpts
} else {
- eprintln!("Warning: Unrecognized section `{title:?}`");
Section::Other
};
}
@@ -148,14 +148,6 @@ pub async fn zeta2_predict(
&request.local_prompt.unwrap_or_default(),
)?;
- let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
- prediction_finished_at = Some(Instant::now());
-
- fs::write(
- LOGS_DIR.join("prediction_response.json"),
- &serde_json::to_string_pretty(&response).unwrap(),
- )?;
-
for included_file in request.request.included_files {
let insertions = vec![(request.request.cursor_point, CURSOR_MARKER)];
result
@@ -177,6 +169,12 @@ pub async fn zeta2_predict(
&mut excerpts_text,
);
}
+
+ let response = request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
+ let response = zeta2::text_from_response(response).unwrap_or_default();
+ prediction_finished_at = Some(Instant::now());
+ fs::write(LOGS_DIR.join("prediction_response.md"), &response)?;
+
break;
}
}