use anyhow::Result;
use assistant_tooling::ToolFunctionDefinition;
use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::Global;
use std::sync::Arc;

pub use open_ai::RequestMessage as CompletionMessage;

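/// Shared handle to the active completion backend.
///
/// Cloning is cheap (it clones an `Arc`), and the type implements `Global` so
/// it can be stored as a gpui global and reached from anywhere in the app.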
#[derive(Clone)]
pub struct CompletionProvider(Arc<dyn CompletionProviderBackend>);

impl CompletionProvider {
    pub fn new(backend: impl CompletionProviderBackend) -> Self {
        Self(Arc::new(backend))
    }

    pub fn default_model(&self) -> String {
        self.0.default_model()
    }

    pub fn available_models(&self) -> Vec<String> {
        self.0.available_models()
    }

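    /// Requests a completion from the backend, returning a boxed stream of
    /// response message deltas.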
    pub fn complete(
        &self,
        model: String,
        messages: Vec<CompletionMessage>,
        stop: Vec<String>,
        temperature: f32,
        tools: &[ToolFunctionDefinition],
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
    {
        self.0.complete(model, messages, stop, temperature, tools)
    }
}

impl Global for CompletionProvider {}

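// A minimal usage sketch (hypothetical; assumes an authenticated `client` and
// a gpui app context `cx` in which to register the global):
//
//     let provider = CompletionProvider::new(CloudCompletionProvider::new(client));
//     cx.set_global(provider.clone());
//     let completion = provider.complete(
//         provider.default_model(),
//         vec![CompletionMessage::User {
//             content: "Hello!".into(),
//         }],
//         Vec::new(),
//         1.0,
//         &[],
//     );

/// Interface implemented by every completion backend.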
pub trait CompletionProviderBackend: 'static {
    fn default_model(&self) -> String;
    fn available_models(&self) -> Vec<String>;
    fn complete(
        &self,
        model: String,
        messages: Vec<CompletionMessage>,
        stop: Vec<String>,
        temperature: f32,
        tools: &[ToolFunctionDefinition],
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>;
}

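// Backends other than the cloud provider can implement the same trait. A
// minimal test double might look like this (a sketch; `FakeCompletionProvider`
// is hypothetical and not part of this crate):
//
//     struct FakeCompletionProvider;
//
//     impl CompletionProviderBackend for FakeCompletionProvider {
//         fn default_model(&self) -> String {
//             "fake-model".into()
//         }
//
//         fn available_models(&self) -> Vec<String> {
//             vec!["fake-model".into()]
//         }
//
//         fn complete(
//             &self,
//             _model: String,
//             _messages: Vec<CompletionMessage>,
//             _stop: Vec<String>,
//             _temperature: f32,
//             _tools: &[ToolFunctionDefinition],
//         ) -> BoxFuture<
//             'static,
//             Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>,
//         > {
//             async { Ok(futures::stream::empty().boxed()) }.boxed()
//         }
//     }

/// Backend that forwards completion requests over the shared `Client` RPC
/// connection.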
pub struct CloudCompletionProvider {
    client: Arc<Client>,
}

impl CloudCompletionProvider {
    pub fn new(client: Arc<Client>) -> Self {
        Self { client }
    }
}

impl CompletionProviderBackend for CloudCompletionProvider {
    fn default_model(&self) -> String {
        "gpt-4-turbo".into()
    }

    fn available_models(&self) -> Vec<String> {
        vec!["gpt-4-turbo".into(), "gpt-4".into(), "gpt-3.5-turbo".into()]
    }

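    // Translates the OpenAI-style request into the RPC protocol, issues it
    // over the client connection, and streams back response deltas.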
    fn complete(
        &self,
        model: String,
        messages: Vec<CompletionMessage>,
        stop: Vec<String>,
        temperature: f32,
        tools: &[ToolFunctionDefinition],
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<proto::LanguageModelResponseMessage>>>>
    {
        let client = self.client.clone();
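        // Convert the tool definitions into their protobuf representation.
        // Tools whose parameter schemas fail to serialize to JSON are silently
        // dropped.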
        let tools: Vec<proto::ChatCompletionTool> = tools
            .iter()
            .filter_map(|tool| {
                Some(proto::ChatCompletionTool {
                    variant: Some(proto::chat_completion_tool::Variant::Function(
                        proto::chat_completion_tool::FunctionObject {
                            name: tool.name.clone(),
                            description: Some(tool.description.clone()),
                            parameters: Some(serde_json::to_string(&tool.parameters).ok()?),
                        },
                    )),
                })
            })
            .collect();

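        // Offer an "auto" tool choice only when at least one tool survived
        // conversion; otherwise send no tool choice at all.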
        let tool_choice = match tools.is_empty() {
            true => None,
            false => Some("auto".into()),
        };

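        // Issue the RPC, translating each OpenAI-style message into its
        // protobuf equivalent.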
        async move {
            let stream = client
                .request_stream(proto::CompleteWithLanguageModel {
                    model,
                    messages: messages
                        .into_iter()
                        .map(|message| match message {
                            CompletionMessage::Assistant {
                                content,
                                tool_calls,
                            } => proto::LanguageModelRequestMessage {
                                role: proto::LanguageModelRole::LanguageModelAssistant as i32,
                                content: content.unwrap_or_default(),
                                tool_call_id: None,
                                tool_calls: tool_calls
                                    .into_iter()
                                    .map(|tool_call| match tool_call.content {
                                        open_ai::ToolCallContent::Function { function } => {
                                            proto::ToolCall {
                                                id: tool_call.id,
                                                variant: Some(proto::tool_call::Variant::Function(
                                                    proto::tool_call::FunctionCall {
                                                        name: function.name,
                                                        arguments: function.arguments,
                                                    },
                                                )),
                                            }
                                        }
                                    })
                                    .collect(),
                            },
                            CompletionMessage::User { content } => {
                                proto::LanguageModelRequestMessage {
                                    role: proto::LanguageModelRole::LanguageModelUser as i32,
                                    content,
                                    tool_call_id: None,
                                    tool_calls: Vec::new(),
                                }
                            }
                            CompletionMessage::System { content } => {
                                proto::LanguageModelRequestMessage {
                                    role: proto::LanguageModelRole::LanguageModelSystem as i32,
                                    content,
                                    tool_call_id: None,
                                    tool_calls: Vec::new(),
                                }
                            }
                            CompletionMessage::Tool {
                                content,
                                tool_call_id,
                            } => proto::LanguageModelRequestMessage {
                                role: proto::LanguageModelRole::LanguageModelTool as i32,
                                content,
                                tool_call_id: Some(tool_call_id),
                                tool_calls: Vec::new(),
                            },
                        })
                        .collect(),
                    stop,
                    temperature,
                    tool_choice,
                    tools,
                })
                .await?;

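            // Each streamed response carries a list of choices; keep only the
            // delta of the last choice and skip responses without one.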
            Ok(stream
                .filter_map(|response| async move {
                    match response {
                        Ok(mut response) => Some(Ok(response.choices.pop()?.delta?)),
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed())
        }
        .boxed()
    }
}