1import { getModel, type Model } from "@mariozechner/pi-ai";
2import {
3 type CustomModelConfig,
4 type RumiloConfig,
5} from "../config/schema.js";
6import { ConfigError } from "../util/errors.js";
7
/**
 * Resolve a `provider:model` specifier into a model object.
 *
 * Only the FIRST `:` separates the provider from the model name, so model
 * ids that themselves contain colons (e.g. `someprovider:llama3:8b`) are kept
 * intact instead of being truncated at the second colon.
 *
 * The `custom` provider is looked up in `config.custom_models`; every other
 * provider is delegated to pi-ai's `getModel`.
 *
 * @param modelString - Specifier in `provider:model` form.
 * @param config - Loaded Rumilo configuration (used for `custom:` models).
 * @returns The resolved model.
 * @throws ConfigError if the specifier is malformed, or if the model cannot
 *   be resolved.
 */
export function resolveModel(
  modelString: string,
  config: RumiloConfig,
): Model<any> {
  const separatorIndex = modelString.indexOf(":");
  const provider =
    separatorIndex === -1 ? "" : modelString.slice(0, separatorIndex);
  const modelName =
    separatorIndex === -1 ? "" : modelString.slice(separatorIndex + 1);

  if (!provider || !modelName) {
    throw new ConfigError("Model must be in provider:model format");
  }

  // Handle custom models
  if (provider === "custom") {
    return resolveCustomModel(modelName, config);
  }

  // Handle built-in providers. The provider string is not validated against
  // pi-ai's known-provider union at compile time, hence the cast.
  const model = getModel(provider as any, modelName);
  // NOTE(review): getModel may yield undefined for an unknown provider/model
  // pair. Guard so callers get a ConfigError instead of a later crash.
  if (!model) {
    throw new ConfigError(
      `Unknown model '${modelName}' for provider '${provider}'`,
    );
  }
  return model;
}
26
/**
 * Look up a user-defined model in `config.custom_models` and build it.
 *
 * @param modelName - The part after `custom:` in the model specifier.
 * @param config - Loaded Rumilo configuration.
 * @returns The constructed custom model.
 * @throws ConfigError if no custom models are configured, or if `modelName`
 *   is not one of them.
 */
function resolveCustomModel(modelName: string, config: RumiloConfig): Model<any> {
  const customModels = config.custom_models;
  if (!customModels) {
    throw new ConfigError(
      `No custom models configured. Use 'custom:' prefix only with custom model definitions in config.`,
    );
  }

  // Own-property check. A plain `customModels[modelName]` lookup resolves
  // inherited keys such as "constructor" or "toString" to Object.prototype
  // members. Those are truthy, so they would slip past the not-found check
  // and crash inside buildCustomModel.
  const customConfig = Object.prototype.hasOwnProperty.call(customModels, modelName)
    ? customModels[modelName]
    : undefined;

  if (!customConfig) {
    const names = Object.keys(customModels);
    const available = names.length > 0 ? names.join(", ") : "(none)";
    throw new ConfigError(
      `Custom model '${modelName}' not found. Available custom models: ${available}`,
    );
  }

  return buildCustomModel(customConfig);
}
44
/**
 * Map a snake_case custom model definition from the config file onto the
 * camelCase model object used by pi-ai.
 *
 * Optional pieces are attached only when present in the config:
 * - `cacheRead` / `cacheWrite` when the cost value is not `undefined` (so an
 *   explicit `0` is kept);
 * - `headers` and `compat` when the config value is truthy. `compat` is
 *   converted to camelCase first.
 *
 * @param config - One entry from `custom_models`.
 * @returns The model object.
 */
function buildCustomModel(config: CustomModelConfig): Model<any> {
  const { cost: rawCost } = config;

  const cost: any = {
    input: rawCost.input,
    output: rawCost.output,
    ...(rawCost.cache_read !== undefined ? { cacheRead: rawCost.cache_read } : {}),
    ...(rawCost.cache_write !== undefined ? { cacheWrite: rawCost.cache_write } : {}),
  };

  // Key order mirrors the field order used throughout: identity, transport,
  // capabilities, pricing, limits, then optional extras.
  const descriptor: any = {
    id: config.id,
    name: config.name,
    api: config.api as any,
    provider: config.provider as any,
    baseUrl: config.base_url,
    reasoning: config.reasoning,
    input: config.input,
    cost,
    contextWindow: config.context_window,
    maxTokens: config.max_tokens,
    ...(config.headers ? { headers: config.headers } : {}),
    ...(config.compat ? { compat: convertCompatConfig(config.compat) } : {}),
  };

  return descriptor;
}
84
/**
 * Convert the snake_case `compat` section of a custom model config into the
 * camelCase compat object attached to the model.
 *
 * Only fields the user actually set are emitted. Unset (`undefined`) fields
 * are omitted rather than written as explicit `undefined` keys, so that a
 * consumer merging this object over its own defaults (e.g. via object spread)
 * does not have those defaults clobbered. Explicit `false`/`null` values are
 * preserved.
 *
 * @param compat - The `compat` section; must be defined.
 * @returns A camelCase compat object containing only the provided fields.
 * @throws Error if `compat` is missing.
 */
function convertCompatConfig(
  compat: CustomModelConfig["compat"],
): Record<string, unknown> {
  if (!compat) {
    throw new Error("Compat config is expected to be defined");
  }

  const converted: Record<string, unknown> = {
    supportsStore: compat.supports_store,
    supportsDeveloperRole: compat.supports_developer_role,
    supportsReasoningEffort: compat.supports_reasoning_effort,
    supportsUsageInStreaming: compat.supports_usage_in_streaming,
    maxTokensField: compat.max_tokens_field,
    requiresToolResultName: compat.requires_tool_result_name,
    requiresAssistantAfterToolResult: compat.requires_assistant_after_tool_result,
    requiresThinkingAsText: compat.requires_thinking_as_text,
    requiresMistralToolIds: compat.requires_mistral_tool_ids,
    thinkingFormat: compat.thinking_format,
  };

  // Drop fields the user did not set (see doc comment for why).
  return Object.fromEntries(
    Object.entries(converted).filter(([, value]) => value !== undefined),
  );
}