diff --git a/go.mod b/go.mod index 42f2cb26373d3ddf331b0995f96b4143279741d9..e7793078a5914995552d11422d4d0b7ba503e80b 100644 --- a/go.mod +++ b/go.mod @@ -49,11 +49,16 @@ require ( ) require ( + cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect github.com/bahlo/generic-list-go v0.2.0 // indirect github.com/buger/jsonparser v1.1.1 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/wk8/go-ordered-map/v2 v2.1.8 // indirect + go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect golang.org/x/exp v0.0.0-20250305212735-054e65f0b394 // indirect + golang.org/x/oauth2 v0.25.0 // indirect + golang.org/x/time v0.8.0 // indirect + google.golang.org/api v0.211.0 // indirect ) require ( diff --git a/go.sum b/go.sum index 75e178374abd7f88b2f779ffbef342d822a84431..b7990e90c675695b352a9c85090d694af0fcac66 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,8 @@ cloud.google.com/go v0.116.0 h1:B3fRrSDkLRt5qSHWe40ERJvhvnQwdZiHu0bJOpldweE= cloud.google.com/go v0.116.0/go.mod h1:cEPSRWPzZEswwdr9BxE6ChEn01dWlTaF05LiC2Xs70U= cloud.google.com/go/auth v0.13.0 h1:8Fu8TZy167JkW8Tj3q7dIkr2v4cndv41ouecJx0PAHs= cloud.google.com/go/auth v0.13.0/go.mod h1:COOjD9gwfKNKz+IIduatIhYJQIc0mG3H102r/EMxX6Q= +cloud.google.com/go/auth/oauth2adapt v0.2.6 h1:V6a6XDu2lTwPZWOawrAa9HUK+DB2zfJyTuciBG5hFkU= +cloud.google.com/go/auth/oauth2adapt v0.2.6/go.mod h1:AlmsELtlEBnaNTL7jCj8VQFLy6mbZv0s4Q7NGBeQ5E8= cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= github.com/Azure/azure-sdk-for-go/sdk/azcore v1.17.0 h1:g0EZJwz7xkXQiZAI5xi9f3WWFYBlX1CPTrR+NDToRkQ= @@ -290,6 +292,8 @@ github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0= github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 h1:r6I7RJCN86bpD/FQwedZ0vSixDpwuWREjW9oRMsmqDc= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0/go.mod h1:B9yO6b04uB80CzjedvewuqDhxJxi11s7/GtiGa8bAjI= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 h1:TT4fX+nBOA/+LUkobKGW1ydGcn+G3vRw9+g5HwCphpk= go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0/go.mod h1:L7UH0GbB0p47T4Rri3uHjbpCFYrVrwc1I25QhNPiGK8= go.opentelemetry.io/otel v1.35.0 h1:xKWKPxrxB6OtMCbmMY021CqC45J+3Onta9MqjhnusiQ= @@ -328,6 +332,8 @@ golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= +golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -367,11 +373,15 @@ golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= +golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/api v0.211.0 h1:IUpLjq09jxBSV1lACO33CGY3jsRcbctfGzhj+ZSE/Bg= +google.golang.org/api v0.211.0/go.mod h1:XOloB4MXFH4UTlQSGuNUxw0UT74qdENK8d6JNsXKLi0= google.golang.org/genai v1.3.0 h1:tXhPJF30skOjnnDY7ZnjK3q7IKy4PuAlEA0fk7uEaEI= google.golang.org/genai v1.3.0/go.mod h1:TyfOKRz/QyCaj6f/ZDt505x+YreXnY40l2I6k8TvgqY= google.golang.org/genproto/googleapis/rpc v0.0.0-20250324211829-b45e905df463 h1:e0AIkUUhxyBKh6ssZNrAMeqhA7RKUj42346d1y02i2g= diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index 400867a122cd296036f8839740d5c53106a900ee..3f31335dabb4802b5811947be8c5c0b9b8a03fc4 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -15,6 +15,7 @@ import ( "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/bedrock" "github.com/anthropics/anthropic-sdk-go/option" + "github.com/anthropics/anthropic-sdk-go/vertex" "github.com/charmbracelet/catwalk/pkg/catwalk" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/llm/tools" @@ -26,21 +27,30 @@ var contextLimitRegex = regexp.MustCompile(`input length and ` + "`max_tokens`" type anthropicClient struct { providerOptions providerClientOptions - useBedrock bool + tp AnthropicClientType client anthropic.Client adjustedMaxTokens int // Used when context limit is hit } type AnthropicClient ProviderClient -func newAnthropicClient(opts providerClientOptions, useBedrock bool) AnthropicClient { +type AnthropicClientType string + +const ( + AnthropicClientTypeNormal AnthropicClientType = "normal" + AnthropicClientTypeBedrock AnthropicClientType = "bedrock" + AnthropicClientTypeVertex AnthropicClientType = "vertex" +) + +func newAnthropicClient(opts providerClientOptions, tp AnthropicClientType) AnthropicClient { return &anthropicClient{ providerOptions: opts, - client: createAnthropicClient(opts, useBedrock), + tp: tp, + client: createAnthropicClient(opts, tp), } } -func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropic.Client { +func createAnthropicClient(opts providerClientOptions, tp AnthropicClientType) anthropic.Client { anthropicClientOptions := []option.RequestOption{} // Check if Authorization header is provided in extra headers @@ -67,8 +77,13 @@ func createAnthropicClient(opts providerClientOptions, useBedrock bool) anthropi } else if hasBearerAuth { slog.Debug("Skipping X-Api-Key header because Authorization header is provided") } - if useBedrock { + switch tp { + case AnthropicClientTypeBedrock: anthropicClientOptions = append(anthropicClientOptions, bedrock.WithLoadDefaultConfig(context.Background())) + case AnthropicClientTypeVertex: + project := opts.extraParams["project"] + location := opts.extraParams["location"] + anthropicClientOptions = append(anthropicClientOptions, vertex.WithGoogleAuth(context.Background(), location, project)) } for key, header := range opts.extraHeaders { anthropicClientOptions = append(anthropicClientOptions, option.WithHeaderAdd(key, header)) @@ -478,7 +493,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err if err != nil { return false, 0, fmt.Errorf("failed to resolve API key: %w", err) } - a.client = createAnthropicClient(a.providerOptions, a.useBedrock) + a.client = createAnthropicClient(a.providerOptions, a.tp) return true, 0, nil } diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 8b5b21c36a390e80843504c7c9f6c257156f6379..526d11b5597859853be9314ed618748e3ae40f38 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -52,7 +52,7 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { opts.disableCache = true // Disable cache for Bedrock return &bedrockClient{ providerOptions: opts, - childProvider: newAnthropicClient(anthropicOpts, true), + childProvider: newAnthropicClient(anthropicOpts, AnthropicClientTypeBedrock), } } diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 14b90fa95128efe6bb1f629ca675d6bc3e5fa646..6376561aa437c0dfcd4abeb8f7ed2fd2b182e936 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -173,7 +173,7 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi case catwalk.TypeAnthropic: return &baseProvider[AnthropicClient]{ options: clientOptions, - client: newAnthropicClient(clientOptions, false), + client: newAnthropicClient(clientOptions, AnthropicClientTypeNormal), }, nil case catwalk.TypeOpenAI: return &baseProvider[OpenAIClient]{ diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go index 49a28084ab2096522399c43ea2824c2b12063244..cbc86d8b7428639ea89ad49771ed4515d18adc07 100644 --- a/internal/llm/provider/vertexai.go +++ b/internal/llm/provider/vertexai.go @@ -3,6 +3,7 @@ package provider import ( "context" "log/slog" + "strings" "google.golang.org/genai" ) @@ -22,6 +23,10 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient { return nil } + model := opts.model(opts.modelType) + if strings.Contains(model.ID, "anthropic") || strings.Contains(model.ID, "claude-sonnet") { + return newAnthropicClient(opts, AnthropicClientTypeVertex) + } return &geminiClient{ providerOptions: opts, client: client,