Skip to content

Commit 995dbfe

Browse files
committed
fix: added support for streaming
1 parent ed7d8b2 commit 995dbfe

File tree

3 files changed

+153
-0
lines changed

3 files changed

+153
-0
lines changed

docs/UI_GENERATION.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1183,3 +1183,4 @@ function ButtonComponent({ button }) {
11831183

11841184
**Built with ❤️ for creating beautiful AI-generated UIs**
11851185

1186+

llm/manager.go

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,3 +579,18 @@ func (m *LLMManager) IsStarted() bool {
579579

580580
return m.started
581581
}
582+
583+
// SupportsStreaming checks if a provider supports streaming.
584+
func (m *LLMManager) SupportsStreaming(providerName string) bool {
585+
m.mu.RLock()
586+
defer m.mu.RUnlock()
587+
588+
provider, exists := m.providers[providerName]
589+
if !exists {
590+
return false
591+
}
592+
593+
// Check if provider implements StreamingProvider interface
594+
_, ok := provider.(StreamingProvider)
595+
return ok
596+
}

llm/manager_test.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package llm
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
logger "github.com/xraph/go-utils/log"
8+
"github.com/xraph/go-utils/metrics"
9+
)
10+
11+
// mockStreamingProvider is a mock provider that implements StreamingProvider interface
12+
type mockStreamingProvider struct {
13+
name string
14+
}
15+
16+
func (m *mockStreamingProvider) Name() string { return m.name }
17+
func (m *mockStreamingProvider) Models() []string { return []string{"mock-model"} }
18+
func (m *mockStreamingProvider) Chat(ctx context.Context, request ChatRequest) (ChatResponse, error) {
19+
return ChatResponse{}, nil
20+
}
21+
func (m *mockStreamingProvider) Complete(ctx context.Context, request CompletionRequest) (CompletionResponse, error) {
22+
return CompletionResponse{}, nil
23+
}
24+
func (m *mockStreamingProvider) Embed(ctx context.Context, request EmbeddingRequest) (EmbeddingResponse, error) {
25+
return EmbeddingResponse{}, nil
26+
}
27+
func (m *mockStreamingProvider) GetUsage() LLMUsage { return LLMUsage{} }
28+
func (m *mockStreamingProvider) HealthCheck(ctx context.Context) error { return nil }
29+
func (m *mockStreamingProvider) ChatStream(ctx context.Context, request ChatRequest, handler func(ChatStreamEvent) error) error {
30+
return nil
31+
}
32+
33+
// mockBasicProvider is a mock provider that only implements LLMProvider interface (no streaming)
34+
type mockBasicProvider struct {
35+
name string
36+
}
37+
38+
func (m *mockBasicProvider) Name() string { return m.name }
39+
func (m *mockBasicProvider) Models() []string { return []string{"mock-model"} }
40+
func (m *mockBasicProvider) Chat(ctx context.Context, request ChatRequest) (ChatResponse, error) {
41+
return ChatResponse{}, nil
42+
}
43+
func (m *mockBasicProvider) Complete(ctx context.Context, request CompletionRequest) (CompletionResponse, error) {
44+
return CompletionResponse{}, nil
45+
}
46+
func (m *mockBasicProvider) Embed(ctx context.Context, request EmbeddingRequest) (EmbeddingResponse, error) {
47+
return EmbeddingResponse{}, nil
48+
}
49+
func (m *mockBasicProvider) GetUsage() LLMUsage { return LLMUsage{} }
50+
func (m *mockBasicProvider) HealthCheck(ctx context.Context) error { return nil }
51+
52+
func TestLLMManager_SupportsStreaming(t *testing.T) {
53+
tests := []struct {
54+
name string
55+
providerName string
56+
provider LLMProvider
57+
expectedSupports bool
58+
}{
59+
{
60+
name: "streaming provider",
61+
providerName: "streaming-provider",
62+
provider: &mockStreamingProvider{name: "streaming-provider"},
63+
expectedSupports: true,
64+
},
65+
{
66+
name: "non-streaming provider",
67+
providerName: "basic-provider",
68+
provider: &mockBasicProvider{name: "basic-provider"},
69+
expectedSupports: false,
70+
},
71+
{
72+
name: "non-existent provider",
73+
providerName: "non-existent",
74+
provider: nil,
75+
expectedSupports: false,
76+
},
77+
}
78+
79+
for _, tt := range tests {
80+
t.Run(tt.name, func(t *testing.T) {
81+
// Create manager
82+
manager, err := NewLLMManager(LLMManagerConfig{
83+
Logger: logger.NewTestLogger(),
84+
Metrics: metrics.NewMockMetrics(),
85+
})
86+
if err != nil {
87+
t.Fatalf("Failed to create LLM manager: %v", err)
88+
}
89+
90+
// Register provider if it exists
91+
if tt.provider != nil {
92+
if err := manager.RegisterProvider(tt.provider); err != nil {
93+
t.Fatalf("Failed to register provider: %v", err)
94+
}
95+
}
96+
97+
// Test SupportsStreaming
98+
supports := manager.SupportsStreaming(tt.providerName)
99+
if supports != tt.expectedSupports {
100+
t.Errorf("SupportsStreaming(%q) = %v, want %v", tt.providerName, supports, tt.expectedSupports)
101+
}
102+
})
103+
}
104+
}
105+
106+
func TestLLMManager_SupportsStreaming_ThreadSafety(t *testing.T) {
107+
manager, err := NewLLMManager(LLMManagerConfig{
108+
Logger: logger.NewTestLogger(),
109+
Metrics: metrics.NewMockMetrics(),
110+
})
111+
if err != nil {
112+
t.Fatalf("Failed to create LLM manager: %v", err)
113+
}
114+
115+
providerName := "streaming-test"
116+
117+
// Register a streaming provider
118+
if err := manager.RegisterProvider(&mockStreamingProvider{name: providerName}); err != nil {
119+
t.Fatalf("Failed to register provider: %v", err)
120+
}
121+
122+
// Test concurrent access
123+
done := make(chan bool)
124+
for i := 0; i < 10; i++ {
125+
go func() {
126+
for j := 0; j < 100; j++ {
127+
manager.SupportsStreaming(providerName)
128+
}
129+
done <- true
130+
}()
131+
}
132+
133+
// Wait for all goroutines to complete
134+
for i := 0; i < 10; i++ {
135+
<-done
136+
}
137+
}

0 commit comments

Comments
 (0)