config.go

  1package config
  2
  3import (
  4	"fmt"
  5	"os"
  6	"strings"
  7
  8	"github.com/kujtimiihoxha/termai/internal/llm/models"
  9	"github.com/spf13/viper"
 10)
 11
 12type MCPType string
 13
 14const (
 15	MCPStdio MCPType = "stdio"
 16	MCPSse   MCPType = "sse"
 17)
 18
 19type MCPServer struct {
 20	Command string            `json:"command"`
 21	Env     []string          `json:"env"`
 22	Args    []string          `json:"args"`
 23	Type    MCPType           `json:"type"`
 24	URL     string            `json:"url"`
 25	Headers map[string]string `json:"headers"`
 26	// TODO: add permissions configuration
 27	// TODO: add the ability to specify the tools to import
 28}
 29
 30type Model struct {
 31	Coder          models.ModelID `json:"coder"`
 32	CoderMaxTokens int64          `json:"coderMaxTokens"`
 33
 34	Task          models.ModelID `json:"task"`
 35	TaskMaxTokens int64          `json:"taskMaxTokens"`
 36	// TODO: Maybe support multiple models for different purposes
 37}
 38
 39type Provider struct {
 40	APIKey  string `json:"apiKey"`
 41	Enabled bool   `json:"enabled"`
 42}
 43
 44type Data struct {
 45	Directory string `json:"directory"`
 46}
 47
 48type Log struct {
 49	Level string `json:"level"`
 50}
 51
 52type Config struct {
 53	Data       *Data                             `json:"data,omitempty"`
 54	Log        *Log                              `json:"log,omitempty"`
 55	MCPServers map[string]MCPServer              `json:"mcpServers,omitempty"`
 56	Providers  map[models.ModelProvider]Provider `json:"providers,omitempty"`
 57
 58	Model *Model `json:"model,omitempty"`
 59}
 60
 61var cfg *Config
 62
 63const (
 64	defaultDataDirectory = ".termai"
 65	defaultLogLevel      = "info"
 66	defaultMaxTokens     = int64(5000)
 67	termai               = "termai"
 68)
 69
 70func Load(debug bool) error {
 71	if cfg != nil {
 72		return nil
 73	}
 74
 75	viper.SetConfigName(fmt.Sprintf(".%s", termai))
 76	viper.SetConfigType("json")
 77	viper.AddConfigPath("$HOME")
 78	viper.AddConfigPath(fmt.Sprintf("$XDG_CONFIG_HOME/%s", termai))
 79	viper.SetEnvPrefix(strings.ToUpper(termai))
 80
 81	// Add defaults
 82	viper.SetDefault("data.directory", defaultDataDirectory)
 83	if debug {
 84		viper.Set("log.level", "debug")
 85	} else {
 86		viper.SetDefault("log.level", defaultLogLevel)
 87	}
 88
 89	defaultModelSet := false
 90	if os.Getenv("ANTHROPIC_API_KEY") != "" {
 91		viper.SetDefault("providers.anthropic.apiKey", os.Getenv("ANTHROPIC_API_KEY"))
 92		viper.SetDefault("providers.anthropic.enabled", true)
 93		viper.SetDefault("model.coder", models.Claude37Sonnet)
 94		viper.SetDefault("model.task", models.Claude37Sonnet)
 95		defaultModelSet = true
 96	}
 97	if os.Getenv("OPENAI_API_KEY") != "" {
 98		viper.SetDefault("providers.openai.apiKey", os.Getenv("OPENAI_API_KEY"))
 99		viper.SetDefault("providers.openai.enabled", true)
100		if !defaultModelSet {
101			viper.SetDefault("model.coder", models.GPT4o)
102			viper.SetDefault("model.task", models.GPT4o)
103			defaultModelSet = true
104		}
105	}
106	if os.Getenv("GEMINI_API_KEY") != "" {
107		viper.SetDefault("providers.gemini.apiKey", os.Getenv("GEMINI_API_KEY"))
108		viper.SetDefault("providers.gemini.enabled", true)
109		if !defaultModelSet {
110			viper.SetDefault("model.coder", models.GRMINI20Flash)
111			viper.SetDefault("model.task", models.GRMINI20Flash)
112			defaultModelSet = true
113		}
114	}
115	if os.Getenv("GROQ_API_KEY") != "" {
116		viper.SetDefault("providers.groq.apiKey", os.Getenv("GROQ_API_KEY"))
117		viper.SetDefault("providers.groq.enabled", true)
118		if !defaultModelSet {
119			viper.SetDefault("model.coder", models.QWENQwq)
120			viper.SetDefault("model.task", models.QWENQwq)
121			defaultModelSet = true
122		}
123	}
124	// TODO: add more providers
125	cfg = &Config{}
126
127	err := viper.ReadInConfig()
128	if err != nil {
129		if _, ok := err.(viper.ConfigFileNotFoundError); !ok {
130			return err
131		}
132	}
133	local := viper.New()
134	local.SetConfigName(fmt.Sprintf(".%s", termai))
135	local.SetConfigType("json")
136	local.AddConfigPath(".")
137	// load local config, this will override the global config
138	if err = local.ReadInConfig(); err == nil {
139		viper.MergeConfigMap(local.AllSettings())
140	}
141	viper.Unmarshal(cfg)
142
143	if cfg.Model != nil && cfg.Model.CoderMaxTokens <= 0 {
144		cfg.Model.CoderMaxTokens = defaultMaxTokens
145	}
146	if cfg.Model != nil && cfg.Model.TaskMaxTokens <= 0 {
147		cfg.Model.TaskMaxTokens = defaultMaxTokens
148	}
149
150	for _, v := range cfg.MCPServers {
151		if v.Type == "" {
152			v.Type = MCPStdio
153		}
154	}
155
156	workdir, err := os.Getwd()
157	if err != nil {
158		return err
159	}
160	viper.Set("wd", workdir)
161	return nil
162}
163
164func Get() *Config {
165	if cfg == nil {
166		err := Load(false)
167		if err != nil {
168			panic(err)
169		}
170	}
171	return cfg
172}
173
174func WorkingDirectory() string {
175	return viper.GetString("wd")
176}
177
178func Write() error {
179	return viper.WriteConfig()
180}