1use crate::{
2 assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
3};
4use anyhow::Result;
5use futures::StreamExt as _;
6use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
7use gpui::{AnyView, AppContext, Task};
8use http::HttpClient;
9use ollama::{
10 get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
11 Role as OllamaRole,
12};
13use std::sync::Arc;
14use std::time::Duration;
15use ui::{prelude::*, ButtonLike, ElevationIndex};
16
17const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
18
/// Chat-completion provider backed by a locally running Ollama server.
pub struct OllamaCompletionProvider {
    /// Base URL of the Ollama HTTP API.
    api_url: String,
    /// Model currently selected in the assistant settings.
    model: OllamaModel,
    http_client: Arc<dyn HttpClient>,
    /// Optional low-speed timeout forwarded to streaming completion requests.
    low_speed_timeout: Option<Duration>,
    /// Settings revision these fields were built from; used by callers to
    /// detect stale provider state (see `settings_version()`).
    settings_version: usize,
    /// Models reported by the server; populated by `fetch_models`, empty
    /// until the server has been reached successfully.
    available_models: Vec<OllamaModel>,
}
27
28impl OllamaCompletionProvider {
29 pub fn new(
30 model: OllamaModel,
31 api_url: String,
32 http_client: Arc<dyn HttpClient>,
33 low_speed_timeout: Option<Duration>,
34 settings_version: usize,
35 cx: &AppContext,
36 ) -> Self {
37 cx.spawn({
38 let api_url = api_url.clone();
39 let client = http_client.clone();
40 let model = model.name.clone();
41
42 |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
43 })
44 .detach_and_log_err(cx);
45
46 Self {
47 api_url,
48 model,
49 http_client,
50 low_speed_timeout,
51 settings_version,
52 available_models: Default::default(),
53 }
54 }
55
56 pub fn update(
57 &mut self,
58 model: OllamaModel,
59 api_url: String,
60 low_speed_timeout: Option<Duration>,
61 settings_version: usize,
62 cx: &AppContext,
63 ) {
64 cx.spawn({
65 let api_url = api_url.clone();
66 let client = self.http_client.clone();
67 let model = model.name.clone();
68
69 |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
70 })
71 .detach_and_log_err(cx);
72
73 self.model = model;
74 self.api_url = api_url;
75 self.low_speed_timeout = low_speed_timeout;
76 self.settings_version = settings_version;
77 }
78
79 pub fn available_models(&self) -> impl Iterator<Item = &OllamaModel> {
80 self.available_models.iter()
81 }
82
83 pub fn settings_version(&self) -> usize {
84 self.settings_version
85 }
86
87 pub fn is_authenticated(&self) -> bool {
88 !self.available_models.is_empty()
89 }
90
91 pub fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
92 if self.is_authenticated() {
93 Task::ready(Ok(()))
94 } else {
95 self.fetch_models(cx)
96 }
97 }
98
99 pub fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
100 self.fetch_models(cx)
101 }
102
103 pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
104 let http_client = self.http_client.clone();
105 let api_url = self.api_url.clone();
106
107 // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
108 cx.spawn(|mut cx| async move {
109 let models = get_models(http_client.as_ref(), &api_url, None).await?;
110
111 let mut models: Vec<OllamaModel> = models
112 .into_iter()
113 // Since there is no metadata from the Ollama API
114 // indicating which models are embedding models,
115 // simply filter out models with "-embed" in their name
116 .filter(|model| !model.name.contains("-embed"))
117 .map(|model| OllamaModel::new(&model.name))
118 .collect();
119
120 models.sort_by(|a, b| a.name.cmp(&b.name));
121
122 cx.update_global::<CompletionProvider, _>(|provider, _cx| {
123 if let CompletionProvider::Ollama(provider) = provider {
124 provider.available_models = models;
125 }
126 })
127 })
128 }
129
130 pub fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
131 cx.new_view(|cx| DownloadOllamaMessage::new(cx)).into()
132 }
133
134 pub fn model(&self) -> OllamaModel {
135 self.model.clone()
136 }
137
138 pub fn count_tokens(
139 &self,
140 request: LanguageModelRequest,
141 _cx: &AppContext,
142 ) -> BoxFuture<'static, Result<usize>> {
143 // There is no endpoint for this _yet_ in Ollama
144 // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
145 let token_count = request
146 .messages
147 .iter()
148 .map(|msg| msg.content.chars().count())
149 .sum::<usize>()
150 / 4;
151
152 async move { Ok(token_count) }.boxed()
153 }
154
155 pub fn complete(
156 &self,
157 request: LanguageModelRequest,
158 ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
159 let request = self.to_ollama_request(request);
160
161 let http_client = self.http_client.clone();
162 let api_url = self.api_url.clone();
163 let low_speed_timeout = self.low_speed_timeout;
164 async move {
165 let request =
166 stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
167 let response = request.await?;
168 let stream = response
169 .filter_map(|response| async move {
170 match response {
171 Ok(delta) => {
172 let content = match delta.message {
173 ChatMessage::User { content } => content,
174 ChatMessage::Assistant { content } => content,
175 ChatMessage::System { content } => content,
176 };
177 Some(Ok(content))
178 }
179 Err(error) => Some(Err(error)),
180 }
181 })
182 .boxed();
183 Ok(stream)
184 }
185 .boxed()
186 }
187
188 fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
189 let model = match request.model {
190 LanguageModel::Ollama(model) => model,
191 _ => self.model(),
192 };
193
194 ChatRequest {
195 model: model.name,
196 messages: request
197 .messages
198 .into_iter()
199 .map(|msg| match msg.role {
200 Role::User => ChatMessage::User {
201 content: msg.content,
202 },
203 Role::Assistant => ChatMessage::Assistant {
204 content: msg.content,
205 },
206 Role::System => ChatMessage::System {
207 content: msg.content,
208 },
209 })
210 .collect(),
211 keep_alive: model.keep_alive,
212 stream: true,
213 options: Some(ChatOptions {
214 num_ctx: Some(model.max_tokens),
215 stop: Some(request.stop),
216 temperature: Some(request.temperature),
217 ..Default::default()
218 }),
219 }
220 }
221}
222
223impl From<Role> for ollama::Role {
224 fn from(val: Role) -> Self {
225 match val {
226 Role::User => OllamaRole::User,
227 Role::Assistant => OllamaRole::Assistant,
228 Role::System => OllamaRole::System,
229 }
230 }
231}
232
/// Stateless view shown when no Ollama server is reachable, prompting the
/// user to download and install Ollama.
struct DownloadOllamaMessage {}

impl DownloadOllamaMessage {
    pub fn new(_cx: &mut ViewContext<Self>) -> Self {
        Self {}
    }

    /// Renders the filled call-to-action button; clicking it opens the
    /// Ollama download page in the user's browser.
    fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("download_ollama_button")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Get Ollama"))
            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
    }
}
249
impl Render for DownloadOllamaMessage {
    // Lays out the explanatory message with the download button centered
    // beneath it, filling the available space.
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine.").size(LabelSize::Large))
            .child(
                h_flex()
                    .w_full()
                    .p_4()
                    .justify_center()
                    .child(
                        self.render_download_button(cx)
                    )
            )
            .into_any()
    }
}