use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, KeepAlive, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::LanguageModelCompletionEvent;
use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

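/// A user-configured model entry from Zed's settings.
///
/// A sketch of a settings entry that deserializes into this struct; the
/// surrounding settings path and the `keep_alive` literal are assumptions,
/// not confirmed by this file:
///
/// ```json
/// {
///   "name": "llama3.2:latest",
///   "display_name": "Llama 3.2",
///   "max_tokens": 16384,
///   "keep_alive": "10m"
/// }
/// ```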
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.2:latest")
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context length (the `num_ctx` parameter, a.k.a. `n_ctx`)
    pub max_tokens: usize,
    /// How long the model stays loaded in memory after the last request
    pub keep_alive: Option<KeepAlive>,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
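    /// Ollama has no real auth; "authenticated" here just means the local
    /// server responded with at least one installed model.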
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", check that it's up by fetching the models
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut ModelContext<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
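    /// Maps Zed's provider-agnostic request onto Ollama's chat format. Roles
    /// translate one-to-one, and `num_ctx` is set from the configured
    /// `max_tokens` so the server allocates the full context window.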
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            tools: vec![],
        }
    }
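
    /// Issues a single non-streaming chat request; used by `use_any_tool`,
    /// where the response's tool calls are inspected as a whole.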
    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
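        // As a stopgap, this estimates roughly four characters per token, a
        // common rule of thumb for English text; it can be off for code or
        // non-Latin scripts.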
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

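    /// Streams chat deltas from Ollama and maps each delta's message content
    /// to a `LanguageModelCompletionEvent::Text`. The outer future resolves
    /// once `request_limiter` admits the request and the HTTP stream opens.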
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

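    /// Runs a single non-streaming completion and returns a one-item stream
    /// containing the arguments of the first tool call whose name matches
    /// `tool_name`. The full response is awaited before anything is emitted,
    /// since tool calls arrive on the completed assistant message.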
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        use ollama::{OllamaFunctionTool, OllamaTool};

        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

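    /// Kicks off a fresh model fetch. Errors are logged rather than surfaced
    /// here; the view re-renders from `State` once models arrive.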
    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Get up and running with Llama 3.2, Mistral, Gemma 2, and other large language models with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.2")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button so that its spacing matches the
                            // other buttons; it should always stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}