use anyhow::Result;
use assistant_tooling::ToolFunctionDefinition;
use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AppContext, Global};
use std::sync::Arc;

pub use open_ai::RequestMessage as CompletionMessage;
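
/// A cheap-to-clone handle to whichever completion backend is registered as
/// the GPUI global. A minimal setup sketch (assuming a gpui `AppContext`
/// named `cx` and an already-connected `client`; both are illustrative):
///
/// ```ignore
/// cx.set_global(CompletionProvider::new(CloudCompletionProvider::new(client)));
/// let model = CompletionProvider::get(cx).default_model();
/// ```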
#[derive(Clone)]
pub struct CompletionProvider(Arc<dyn CompletionProviderBackend>);

impl CompletionProvider {
    pub fn get(cx: &AppContext) -> &Self {
        cx.global::<CompletionProvider>()
    }

    pub fn new(backend: impl CompletionProviderBackend) -> Self {
        Self(Arc::new(backend))
    }

    pub fn default_model(&self) -> String {
        self.0.default_model()
    }

    pub fn available_models(&self) -> Vec<String> {
        self.0.available_models()
    }

    pub fn complete(
        &self,
        model: String,
        messages: Vec<CompletionMessage>,
        stop: Vec<String>,
        temperature: f32,
        tools: Vec<ToolFunctionDefinition>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
    {
        self.0.complete(model, messages, stop, temperature, tools)
    }
}

impl Global for CompletionProvider {}
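
/// The interface a completion backend must implement: model discovery plus a
/// streaming completion request. Returning boxed futures and streams keeps
/// the trait object-safe, so implementations can sit behind the
/// `Arc<dyn CompletionProviderBackend>` held by [`CompletionProvider`].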
pub trait CompletionProviderBackend: 'static {
    fn default_model(&self) -> String;
    fn available_models(&self) -> Vec<String>;
    fn complete(
        &self,
        model: String,
        messages: Vec<CompletionMessage>,
        stop: Vec<String>,
        temperature: f32,
        tools: Vec<ToolFunctionDefinition>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>;
}
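
/// A [`CompletionProviderBackend`] that forwards completion requests to the
/// remote language-model service over the shared RPC [`Client`].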
pub struct CloudCompletionProvider {
    client: Arc<Client>,
}

impl CloudCompletionProvider {
    pub fn new(client: Arc<Client>) -> Self {
        Self { client }
    }
}

impl CompletionProviderBackend for CloudCompletionProvider {
    fn default_model(&self) -> String {
        "gpt-4-turbo".into()
    }

    fn available_models(&self) -> Vec<String> {
        vec!["gpt-4-turbo".into(), "gpt-4".into(), "gpt-3.5-turbo".into()]
    }

    fn complete(
        &self,
        model: String,
        messages: Vec<CompletionMessage>,
        stop: Vec<String>,
        temperature: f32,
        tools: Vec<ToolFunctionDefinition>,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
    {
        let client = self.client.clone();
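        // Convert each tool definition to its proto representation; a tool
        // whose JSON schema fails to serialize is silently dropped here.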
        let tools: Vec<proto::ChatCompletionTool> = tools
            .iter()
            .filter_map(|tool| {
                Some(proto::ChatCompletionTool {
                    variant: Some(proto::chat_completion_tool::Variant::Function(
                        proto::chat_completion_tool::FunctionObject {
                            name: tool.name.clone(),
                            description: Some(tool.description.clone()),
                            parameters: Some(serde_json::to_string(&tool.parameters).ok()?),
                        },
                    )),
                })
            })
            .collect();
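
        // Only ask the model to choose tools automatically when at least one
        // tool survived the conversion above.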
        let tool_choice = match tools.is_empty() {
            true => None,
            false => Some("auto".into()),
        };
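
        // Stream the completion over RPC, translating each chat message into
        // its proto request form.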
        async move {
            let stream = client
                .request_stream(proto::CompleteWithLanguageModel {
                    model,
                    messages: messages
                        .into_iter()
                        .map(|message| match message {
                            CompletionMessage::Assistant {
                                content,
                                tool_calls,
                            } => proto::LanguageModelRequestMessage {
                                role: proto::LanguageModelRole::LanguageModelAssistant as i32,
                                content: content.unwrap_or_default(),
                                tool_call_id: None,
                                tool_calls: tool_calls
                                    .into_iter()
                                    .map(|tool_call| match tool_call.content {
                                        open_ai::ToolCallContent::Function { function } => {
                                            proto::ToolCall {
                                                id: tool_call.id,
                                                variant: Some(proto::tool_call::Variant::Function(
                                                    proto::tool_call::FunctionCall {
                                                        name: function.name,
                                                        arguments: function.arguments,
                                                    },
                                                )),
                                            }
                                        }
                                    })
                                    .collect(),
                            },
                            CompletionMessage::User { content } => {
                                proto::LanguageModelRequestMessage {
                                    role: proto::LanguageModelRole::LanguageModelUser as i32,
                                    content,
                                    tool_call_id: None,
                                    tool_calls: Vec::new(),
                                }
                            }
                            CompletionMessage::System { content } => {
                                proto::LanguageModelRequestMessage {
                                    role: proto::LanguageModelRole::LanguageModelSystem as i32,
                                    content,
                                    tool_calls: Vec::new(),
                                    tool_call_id: None,
                                }
                            }
                            CompletionMessage::Tool {
                                content,
                                tool_call_id,
                            } => proto::LanguageModelRequestMessage {
                                role: proto::LanguageModelRole::LanguageModelTool as i32,
                                content,
                                tool_call_id: Some(tool_call_id),
                                tool_calls: Vec::new(),
                            },
                        })
                        .collect(),
                    stop,
                    temperature,
                    tool_choice,
                    tools,
                })
                .await?;
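
            // Reduce each streamed chunk to the delta of its last choice,
            // skipping chunks with no choices or no delta.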
            Ok(stream
                .filter_map(|response| async move {
                    match response {
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed())
        }
        .boxed()
    }
}