1package provider
 2
 3import (
 4	"context"
 5	"errors"
 6	"fmt"
 7	"strings"
 8
 9	"github.com/charmbracelet/catwalk/pkg/catwalk"
10	"github.com/charmbracelet/crush/internal/config"
11	"github.com/charmbracelet/crush/internal/llm/tools"
12	"github.com/charmbracelet/crush/internal/message"
13)
14
// bedrockClient is the Bedrock provider client. It holds the options it was
// created with (providerOptions) and forwards send and stream calls to
// childProvider. childProvider is nil when no supported child client could be
// selected; in that case send and stream report an error instead of
// delegating.
type bedrockClient struct {
	providerOptions providerClientOptions
	childProvider   ProviderClient
}
19
// BedrockClient is the type returned by newBedrockClient. It is a defined
// type whose underlying type is ProviderClient.
type BedrockClient ProviderClient
21
22func newBedrockClient(opts providerClientOptions) BedrockClient {
23	// Get AWS region from environment
24	region := opts.extraParams["region"]
25	if region == "" {
26		region = "us-east-1" // default region
27	}
28	if len(region) < 2 {
29		return &bedrockClient{
30			providerOptions: opts,
31			childProvider:   nil, // Will cause an error when used
32		}
33	}
34
35	opts.model = func(modelType config.SelectedModelType) catwalk.Model {
36		model := config.Get().GetModelByType(modelType)
37
38		// Prefix the model name with region
39		regionPrefix := region[:2]
40		modelName := model.ID
41		model.ID = fmt.Sprintf("%s.%s", regionPrefix, modelName)
42		return *model
43	}
44
45	model := opts.model(opts.modelType)
46
47	// Determine which provider to use based on the model
48	if strings.Contains(string(model.ID), "anthropic") {
49		// Create Anthropic client with Bedrock configuration
50		anthropicOpts := opts
51		// TODO: later find a way to check if the AWS account has caching enabled
52		opts.disableCache = true // Disable cache for Bedrock
53		return &bedrockClient{
54			providerOptions: opts,
55			childProvider:   newAnthropicClient(anthropicOpts, AnthropicClientTypeBedrock),
56		}
57	}
58
59	// Return client with nil childProvider if model is not supported
60	// This will cause an error when used
61	return &bedrockClient{
62		providerOptions: opts,
63		childProvider:   nil,
64	}
65}
66
67func (b *bedrockClient) send(ctx context.Context, messages []message.Message, tools []tools.BaseTool) (*ProviderResponse, error) {
68	if b.childProvider == nil {
69		return nil, errors.New("unsupported model for bedrock provider")
70	}
71	return b.childProvider.send(ctx, messages, tools)
72}
73
74func (b *bedrockClient) stream(ctx context.Context, messages []message.Message, tools []tools.BaseTool) <-chan ProviderEvent {
75	eventChan := make(chan ProviderEvent)
76
77	if b.childProvider == nil {
78		go func() {
79			eventChan <- ProviderEvent{
80				Type:  EventError,
81				Error: errors.New("unsupported model for bedrock provider"),
82			}
83			close(eventChan)
84		}()
85		return eventChan
86	}
87
88	return b.childProvider.stream(ctx, messages, tools)
89}
90
91func (b *bedrockClient) Model() catwalk.Model {
92	return b.providerOptions.model(b.providerOptions.modelType)
93}