bedrock.go

 1package provider
 2
 3import (
 4	"context"
 5	"errors"
 6	"fmt"
 7	"strings"
 8
 9	"github.com/charmbracelet/crush/internal/config"
10	"github.com/charmbracelet/crush/internal/llm/tools"
11	"github.com/charmbracelet/crush/internal/message"
12)
13
// bedrockClient adapts AWS Bedrock to the provider interface by delegating
// every request to a model-specific child provider chosen in newBedrockClient.
type bedrockClient struct {
	// providerOptions are the options the client was built with, including the
	// region-prefixing model lookup installed by newBedrockClient.
	providerOptions providerClientOptions
	// childProvider performs the actual requests. It is nil when the model (or
	// region) is unsupported, in which case send and stream report an error.
	childProvider   ProviderClient
}
18
// BedrockClient is the type returned by newBedrockClient. It is a defined type
// (not an alias) whose underlying type is ProviderClient, so it exposes the
// same method set.
type BedrockClient ProviderClient
20
21func newBedrockClient(opts providerClientOptions) BedrockClient {
22	// Get AWS region from environment
23	region := opts.extraParams["region"]
24	if region == "" {
25		region = "us-east-1" // default region
26	}
27	if len(region) < 2 {
28		return &bedrockClient{
29			providerOptions: opts,
30			childProvider:   nil, // Will cause an error when used
31		}
32	}
33
34	opts.model = func(modelType config.ModelType) config.Model {
35		model := config.GetModel(modelType)
36
37		// Prefix the model name with region
38		regionPrefix := region[:2]
39		modelName := model.ID
40		model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
41		return model
42	}
43
44	model := opts.model(opts.modelType)
45
46	// Determine which provider to use based on the model
47	if strings.Contains(string(model.ID), "anthropic") {
48		// Create Anthropic client with Bedrock configuration
49		anthropicOpts := opts
50		// TODO: later find a way to check if the AWS account has caching enabled
51		opts.disableCache = true // Disable cache for Bedrock
52		return &bedrockClient{
53			providerOptions: opts,
54			childProvider:   newAnthropicClient(anthropicOpts, true),
55		}
56	}
57
58	// Return client with nil childProvider if model is not supported
59	// This will cause an error when used
60	return &bedrockClient{
61		providerOptions: opts,
62		childProvider:   nil,
63	}
64}
65
66func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
67	if b.childProvider == nil {
68		return nil, errors.New("unsupported model for bedrock provider")
69	}
70	return b.childProvider.send(ctx, messages, tools)
71}
72
73func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
74	eventChan := make(chan ProviderEvent)
75
76	if b.childProvider == nil {
77		go func() {
78			eventChan <- ProviderEvent{
79				Type:  EventError,
80				Error: errors.New("unsupported model for bedrock provider"),
81			}
82			close(eventChan)
83		}()
84		return eventChan
85	}
86
87	return b.childProvider.stream(ctx, messages, tools)
88}
89
90func (b *bedrockClient) Model() config.Model {
91	return b.providerOptions.model(b.providerOptions.modelType)
92}