-
Notifications
You must be signed in to change notification settings - Fork 131
Expand file tree
/
Copy pathsearch_providers.go
More file actions
398 lines (354 loc) · 19.2 KB
/
search_providers.go
File metadata and controls
398 lines (354 loc) · 19.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package tools
import (
"context"
"encoding/json"
"fmt"
"net/http"
"path"
"sort"
"strings"
"github.com/hashicorp/terraform-mcp-server/pkg/client"
"github.com/hashicorp/terraform-mcp-server/pkg/utils"
log "github.com/sirupsen/logrus"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)
var (
	// sendRegistryCall performs registry requests. It is a package-level
	// variable (defaulting to client.SendRegistryCall) so tests can override
	// registry calls.
	sendRegistryCall = client.SendRegistryCall

	// tierOrder maps a provider tier name to its sorting priority; a lower
	// value sorts first.
	tierOrder = map[string]int{
		"official":  0,
		"partner":   1,
		"community": 2,
	}
)

// providerMatch records one provider returned by the registry search together
// with the documents of that provider that matched the requested service slug.
type providerMatch struct {
	Namespace string               // provider namespace (publisher)
	Name      string               // provider name
	Tier      string               // tier as reported by the registry search
	DocMatch  []client.ProviderDoc // matching docs, in the order the registry listed them
}
// sortMatchesByTier sorts the matches slice in-place by tier using tierOrder.
//
// Tier names are compared case-insensitively. Tiers that are not present in
// tierOrder (including an empty tier) sort after every known tier. The sort is
// stable, so matches with the same rank keep their original relative order.
func sortMatchesByTier(matches []providerMatch) {
	// Rank for tiers missing from tierOrder: one past the largest known rank.
	// Without this, a map miss yields 0 and an unknown tier would tie with
	// "official" and sort ahead of partner/community providers.
	unknownRank := 0
	for _, r := range tierOrder {
		if r >= unknownRank {
			unknownRank = r + 1
		}
	}
	rank := func(tier string) int {
		if r, ok := tierOrder[strings.ToLower(tier)]; ok {
			return r
		}
		return unknownRank
	}
	sort.SliceStable(matches, func(i, j int) bool {
		return rank(matches[i].Tier) < rank(matches[j].Tier)
	})
}
// ResolveProviderDocID creates a tool to get provider details from registry.
//
// The tool is registered as "search_providers" and is annotated as read-only,
// non-destructive and open-world. It takes provider_name, provider_namespace
// and service_slug (all marked required), an optional provider_data_type
// (enum, default "resources") and an optional provider_version. Calls are
// served by resolveProviderDocIDHandler with the given logger.
func ResolveProviderDocID(logger *log.Logger) server.ServerTool {
	const description = `This tool retrieves a list of potential documents based on the service_slug and provider_data_type provided.
You MUST call this function before 'get_provider_details' to obtain a valid tfprovider-compatible provider_doc_id.
Use the most relevant single word as the search query for service_slug, if unsure about the service_slug, use the provider_name for its value.
When selecting the best match, consider the following:
- Title similarity to the query
- Category relevance
Return the selected provider_doc_id and explain your choice.
If there are multiple good matches, mention this but proceed with the most relevant one.`

	tool := mcp.NewTool("search_providers",
		mcp.WithDescription(description),
		mcp.WithTitleAnnotation("Identify the most relevant provider document ID for a Terraform service"),
		mcp.WithOpenWorldHintAnnotation(true),
		mcp.WithReadOnlyHintAnnotation(true),
		mcp.WithDestructiveHintAnnotation(false),
		mcp.WithString("provider_name",
			mcp.Required(),
			mcp.Description("The name of the Terraform provider to perform the read or deployment operation"),
		),
		mcp.WithString("provider_namespace",
			mcp.Required(),
			mcp.Description("The publisher of the Terraform provider, typically the name of the company, or their GitHub organization name that created the provider"),
		),
		mcp.WithString("service_slug",
			mcp.Required(),
			mcp.Description("The slug of the service you want to deploy or read using the Terraform provider, prefer using a single word, use underscores for multiple words and if unsure about the service_slug, use the provider_name for its value"),
		),
		mcp.WithString("provider_data_type",
			mcp.Description("The type of the document to retrieve, for general information use 'guides', for deploying resources use 'resources', for reading pre-deployed resources use 'data-sources', for functions use 'functions', and for overview of the provider use 'overview'"),
			mcp.Enum("resources", "data-sources", "functions", "guides", "overview"),
			mcp.DefaultString("resources"),
		),
		mcp.WithString("provider_version",
			mcp.Description("The version of the Terraform provider to retrieve in the format 'x.y.z', or 'latest' to get the latest version")),
	)

	// The handler only closes over the logger; all per-call state comes from the request.
	handler := func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
		return resolveProviderDocIDHandler(ctx, request, logger)
	}

	return server.ServerTool{Tool: tool, Handler: handler}
}
// resolveProviderDocIDHandler implements the "search_providers" tool.
//
// It resolves the provider (name, namespace, version) from the request via
// resolveProviderDetails and reads the required service_slug. Then, depending
// on provider_data_type, it either:
//   - lists the whole category through the v2 API (providerDetailsV2) when
//     utils.IsV2ProviderDataType accepts the data type, or
//   - searches the providers' docs for entries matching service_slug
//     (searchProvidersDocs).
//
// A failure to obtain the HTTP client is reported as a tool-result error;
// other failures are returned as Go errors.
func resolveProviderDocIDHandler(ctx context.Context, request mcp.CallToolRequest, logger *log.Logger) (*mcp.CallToolResult, error) {
	// For typical provider and namespace hallucinations
	defaultErrorGuide := "please check the provider name, provider namespace or the provider version you're looking for, perhaps the provider is published under a different namespace or company name"

	// Get a simple http client to access the public Terraform registry from context
	httpClient, err := client.GetHttpClientFromContext(ctx, logger)
	if err != nil {
		logger.WithError(err).Error("failed to get http client for public Terraform registry")
		return mcp.NewToolResultError(fmt.Sprintf("failed to get http client for public Terraform registry: %v", err)), nil
	}

	providerDetail, err := resolveProviderDetails(request, httpClient, defaultErrorGuide, logger)
	if err != nil {
		return nil, err
	}

	serviceSlug, err := request.RequireString("service_slug")
	if err != nil {
		return nil, utils.LogAndReturnError(logger, "required input: service_slug is required", err)
	}
	if serviceSlug == "" {
		return nil, utils.LogAndReturnError(logger, "required input: service_slug cannot be empty", nil)
	}
	serviceSlug = strings.ToLower(serviceSlug)

	// The raw request value overrides what resolveProviderDetails stored, so it
	// must be lower-cased here as well. The tool's enum values are lower-case
	// and searchProvidersDocs compares doc.Category by exact equality, so e.g.
	// "Guides" would otherwise never match anything.
	providerDetail.ProviderDataType = strings.ToLower(request.GetString("provider_data_type", "resources"))

	// Check if we need to use v2 API for guides, functions, or overview
	if utils.IsV2ProviderDataType(providerDetail.ProviderDataType) {
		content, err := providerDetailsV2(httpClient, providerDetail, logger)
		if err != nil {
			errMessage := fmt.Sprintf(`finding %s documentation for provider '%s' in the '%s' namespace, %s`, providerDetail.ProviderDataType, providerDetail.ProviderName, providerDetail.ProviderNamespace, defaultErrorGuide)
			return nil, utils.LogAndReturnError(logger, errMessage, err)
		}
		fullContent := fmt.Sprintf("# %s provider docs\n\n%s", providerDetail.ProviderName, content)
		return mcp.NewToolResultText(fullContent), nil
	}

	// Delegate to extracted helper so it can be unit-tested.
	result, err := searchProvidersDocs(httpClient, providerDetail, serviceSlug, defaultErrorGuide, logger)
	if err != nil {
		return nil, err
	}
	return mcp.NewToolResultText(result), nil
}
// searchProvidersDocs contains the core provider-search and prioritization logic.
// It returns the textual result (same content as the tool would return) for easier unit testing.
//
// It queries the v2 registry search for every provider named
// providerDetail.ProviderName. If the search returns no providers, it falls
// back to the single provider described by providerDetail
// (searchSingleProviderDocs). Otherwise, for each returned provider it fetches
// the docs (retrying with that provider's latest version when the requested
// one cannot be fetched). It keeps the HCL docs of
// providerDetail.ProviderDataType whose slug matches serviceSlug, and renders
// the matches grouped by provider and sorted by tier.
//
// An error is returned when the search request fails, its response cannot be
// decoded, or no document matches serviceSlug.
func searchProvidersDocs(httpClient *http.Client, providerDetail client.ProviderDetail, serviceSlug string, defaultErrorGuide string, logger *log.Logger) (string, error) {
	// Enhanced: Search all providers matching the name and prioritize by tier
	searchURI := "providers?filter[name]=" + providerDetail.ProviderName
	searchResp, err := sendRegistryCall(httpClient, "GET", searchURI, logger, "v2")
	if err != nil {
		return "", utils.LogAndReturnError(logger, "error searching providers in registry", err)
	}
	var providerList client.ProviderList
	if err := json.Unmarshal(searchResp, &providerList); err != nil {
		return "", utils.LogAndReturnError(logger, "unmarshalling provider list", err)
	}
	logger.Infof("provider search returned %d providers for name '%s'", len(providerList.Data), providerDetail.ProviderName)

	// If the registry search didn't return any providers, fall back to fetching
	// the single provider directly (preserves previous behavior for cases where
	// provider namespace defaults to hashicorp and the search endpoint may not
	// return results matching our filter).
	if len(providerList.Data) == 0 {
		return searchSingleProviderDocs(httpClient, providerDetail, serviceSlug, defaultErrorGuide, logger)
	}

	var matches []providerMatch
	for _, pdata := range providerList.Data {
		namespace := pdata.Attributes.Namespace
		name := pdata.Attributes.Name
		tier := pdata.Attributes.Tier
		logger.Debugf("search provider entry: namespace=%s name=%s tier=%s", namespace, name, tier)

		providerDocs, ok := fetchProviderDocsWithLatestFallback(httpClient, namespace, name, providerDetail.ProviderVersion, logger)
		if !ok {
			continue // skip providers we can't fetch
		}
		logger.Debugf("fetched %d docs for provider %s/%s", len(providerDocs.Docs), namespace, name)

		docMatches := matchServiceSlugDocs(providerDocs.Docs, providerDetail.ProviderDataType, namespace, name, serviceSlug, logger)
		if len(docMatches) > 0 {
			matches = append(matches, providerMatch{
				Namespace: namespace,
				Name:      name,
				Tier:      tier,
				DocMatch:  docMatches,
			})
		}
	}

	if len(matches) == 0 {
		errMessage := fmt.Sprintf(`finding documentation for service_slug %s, provide a more relevant service_slug if unsure, use the provider_name for its value`, serviceSlug)
		return "", utils.LogAndReturnError(logger, errMessage, nil)
	}

	sortMatchesByTier(matches)

	var builder strings.Builder
	builder.WriteString("Available Documentation (prioritized by provider tier)\n\n")
	builder.WriteString("Tier order: official > partner > community\n\n")
	for _, match := range matches {
		fmt.Fprintf(&builder, "Provider: %s/%s (Tier: %s)\n", match.Namespace, match.Name, match.Tier)
		for _, doc := range match.DocMatch {
			writeProviderDocEntry(&builder, httpClient, doc, logger)
		}
		builder.WriteString("\n")
	}
	return builder.String(), nil
}

// searchSingleProviderDocs renders the docs matching serviceSlug for the single
// provider namespace/name@version described by providerDetail. It is the
// fallback used when the registry name search returns no providers.
func searchSingleProviderDocs(httpClient *http.Client, providerDetail client.ProviderDetail, serviceSlug string, defaultErrorGuide string, logger *log.Logger) (string, error) {
	logger.Infof("falling back to single-provider fetch for %s/%s@%s", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)
	uri := path.Join("providers", providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)
	logger.Debugf("provider docs fetch URI: %s", uri)
	response, err := sendRegistryCall(httpClient, "GET", uri, logger)
	if err != nil {
		// The returned message carries only the guidance text (no underlying
		// error), as before; keep the cause in the debug log so it is not lost.
		logger.Debugf("provider docs fetch for %s failed: %v", uri, err)
		return "", utils.LogAndReturnError(logger, fmt.Sprintf(`getting the "%s" provider, with version "%s" in the %s namespace, %s`, providerDetail.ProviderName, providerDetail.ProviderVersion, providerDetail.ProviderNamespace, defaultErrorGuide), nil)
	}
	var providerDocs client.ProviderDocs
	if err := json.Unmarshal(response, &providerDocs); err != nil {
		return "", utils.LogAndReturnError(logger, "unmarshalling provider docs", err)
	}
	logger.Infof("provider docs returned %d docs for %s/%s@%s", len(providerDocs.Docs), providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)

	docMatches := matchServiceSlugDocs(providerDocs.Docs, providerDetail.ProviderDataType, providerDetail.ProviderNamespace, providerDetail.ProviderName, serviceSlug, logger)
	if len(docMatches) == 0 {
		errMessage := fmt.Sprintf(`finding documentation for service_slug %s, provide a more relevant service_slug if unsure, use the provider_name for its value`, serviceSlug)
		return "", utils.LogAndReturnError(logger, errMessage, nil)
	}

	var builder strings.Builder
	fmt.Fprintf(&builder, "Available Documentation (top matches) for %s in Terraform provider %s/%s version: %s\n\n", providerDetail.ProviderDataType, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)
	builder.WriteString("Each result includes:\n- providerDocID: tfprovider-compatible identifier\n- Title: Service or resource name\n- Category: Type of document\n- Description: Brief summary of the document\n")
	builder.WriteString("For best results, select libraries based on the service_slug match and category of information requested.\n\n---\n\n")
	for _, doc := range docMatches {
		writeProviderDocEntry(&builder, httpClient, doc, logger)
	}
	return builder.String(), nil
}

// fetchProviderDocsWithLatestFallback fetches the docs of namespace/name at
// version. If that request fails, it retries with the provider's latest
// version. ok is false when the provider should be skipped.
func fetchProviderDocsWithLatestFallback(httpClient *http.Client, namespace, name, version string, logger *log.Logger) (client.ProviderDocs, bool) {
	uri := path.Join("providers", namespace, name, version)
	response, err := sendRegistryCall(httpClient, "GET", uri, logger)
	if err != nil {
		// The requested version may not exist in this namespace; try the latest one.
		latestVer, verErr := client.GetLatestProviderVersion(httpClient, namespace, name, logger)
		if verErr != nil {
			logger.Debugf("skipping provider %s/%s: error fetching docs: %v (also failed to get latest version: %v)", namespace, name, err, verErr)
			return client.ProviderDocs{}, false
		}
		uri = path.Join("providers", namespace, name, latestVer)
		response, err = sendRegistryCall(httpClient, "GET", uri, logger)
		if err != nil {
			logger.Debugf("skipping provider %s/%s: error fetching docs with latest version %s: %v", namespace, name, latestVer, err)
			return client.ProviderDocs{}, false
		}
	}
	var providerDocs client.ProviderDocs
	if err := json.Unmarshal(response, &providerDocs); err != nil {
		logger.Debugf("skipping provider %s/%s: error unmarshalling docs: %v", namespace, name, err)
		return client.ProviderDocs{}, false
	}
	return providerDocs, true
}

// matchServiceSlugDocs returns, in input order, the HCL docs of the given
// category for which utils.ContainsSlug matches serviceSlug against either the
// doc slug or "<providerName>_<slug>". Docs for which either ContainsSlug call
// reports an error are skipped.
func matchServiceSlugDocs(docs []client.ProviderDoc, category, namespace, providerName, serviceSlug string, logger *log.Logger) []client.ProviderDoc {
	var matched []client.ProviderDoc
	for _, doc := range docs {
		logger.Tracef("considering doc slug=%s title=%s category=%s language=%s", doc.Slug, doc.Title, doc.Category, doc.Language)
		if doc.Language != "hcl" || doc.Category != category {
			continue
		}
		slugMatch, slugErr := utils.ContainsSlug(doc.Slug, serviceSlug)
		prefixedMatch, prefixedErr := utils.ContainsSlug(fmt.Sprintf("%s_%s", providerName, doc.Slug), serviceSlug)
		if (slugMatch || prefixedMatch) && slugErr == nil && prefixedErr == nil {
			logger.Debugf("matched doc %s for provider %s/%s (slug=%s)", doc.ID, namespace, providerName, doc.Slug)
			matched = append(matched, doc)
		}
	}
	return matched
}

// writeProviderDocEntry appends one "- providerDocID/Title/Category/Description"
// entry for doc to builder. A failed snippet fetch is logged as a warning and
// rendered as an empty description.
func writeProviderDocEntry(builder *strings.Builder, httpClient *http.Client, doc client.ProviderDoc, logger *log.Logger) {
	descriptionSnippet, err := getContentSnippet(httpClient, doc.ID, logger)
	if err != nil {
		logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, err)
	}
	fmt.Fprintf(builder, "- providerDocID: %s\n- Title: %s\n- Category: %s\n- Description: %s\n---\n", doc.ID, doc.Title, doc.Category, descriptionSnippet)
}
// resolveProviderDetails builds a client.ProviderDetail from the tool request.
//
// All inputs are lower-cased. A missing provider_namespace defaults to
// "hashicorp", and a missing provider_version defaults to "latest".
//
// When provider_version does not pass utils.IsValidProviderVersionFormat, the
// latest published version is looked up in the namespace. If that lookup
// fails, the latest version in the "hashicorp" namespace is tried instead
// (unless that is the namespace that just failed); on success the namespace is
// switched to "hashicorp". A provider_data_type rejected by
// utils.IsValidProviderDataType is stored as "".
//
// An error is returned when provider_name is missing or no version can be
// resolved; the returned ProviderDetail is then the zero value.
func resolveProviderDetails(request mcp.CallToolRequest, httpClient *http.Client, defaultErrorGuide string, logger *log.Logger) (client.ProviderDetail, error) {
	const fallbackNamespace = "hashicorp"

	providerName := strings.ToLower(request.GetString("provider_name", ""))
	if providerName == "" {
		return client.ProviderDetail{}, fmt.Errorf("provider_name is required and must be a string")
	}

	providerNamespace := strings.ToLower(request.GetString("provider_namespace", ""))
	if providerNamespace == "" {
		logger.Debugf(`provider_namespace not provided, defaulting to the "%s" namespace`, fallbackNamespace)
		providerNamespace = fallbackNamespace
	}

	providerVersion := strings.ToLower(request.GetString("provider_version", "latest"))
	providerDataType := strings.ToLower(request.GetString("provider_data_type", "resources"))

	providerVersionValue := ""
	if utils.IsValidProviderVersionFormat(providerVersion) {
		providerVersionValue = providerVersion
	} else if latest, err := client.GetLatestProviderVersion(httpClient, providerNamespace, providerName, logger); err != nil {
		logger.Debugf("Error getting latest provider version in %s namespace: %v", providerNamespace, err)
	} else {
		providerVersionValue = latest
	}

	// If no version could be resolved, try the hashicorp namespace. Re-querying
	// it when it is the namespace that just failed would repeat the same call.
	if providerVersionValue == "" {
		notFound := func(searched string) error {
			return utils.LogAndReturnError(logger, fmt.Sprintf(`getting the "%s" provider, with version "%s" in the %s namespace, %s`, providerName, providerVersion, searched, defaultErrorGuide), nil)
		}
		if providerNamespace == fallbackNamespace {
			return client.ProviderDetail{}, notFound(fallbackNamespace)
		}
		latest, err := client.GetLatestProviderVersion(httpClient, fallbackNamespace, providerName, logger)
		if err != nil {
			return client.ProviderDetail{}, notFound(fmt.Sprintf(`"%s" or the "%s"`, providerNamespace, fallbackNamespace))
		}
		// Update the namespace to hashicorp, since that is where the version was found.
		providerNamespace, providerVersionValue = fallbackNamespace, latest
	}

	providerDataTypeValue := ""
	if utils.IsValidProviderDataType(providerDataType) {
		providerDataTypeValue = providerDataType
	}

	return client.ProviderDetail{
		ProviderName:      providerName,
		ProviderNamespace: providerNamespace,
		ProviderVersion:   providerVersionValue,
		ProviderDataType:  providerDataTypeValue,
	}, nil
}
// providerDetailsV2 lists the documentation items of one provider category
// (providerDetail.ProviderDataType) using the v2 registry API.
//
// It first resolves the provider version ID. For the "overview" category, it
// returns client.GetProviderOverviewDocs directly. For any other category, it
// fetches every HCL doc of that category through client.SendPaginatedRegistryCall.
// It renders each doc as a providerDocID/Title/Category/Description entry; a
// failed snippet fetch only logs a warning and leaves the description empty.
//
// An error is returned when the version ID or the docs cannot be fetched, or
// when the category has no docs.
func providerDetailsV2(httpClient *http.Client, providerDetail client.ProviderDetail, logger *log.Logger) (string, error) {
	versionID, err := client.GetProviderVersionID(httpClient, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion, logger)
	if err != nil {
		return "", utils.LogAndReturnError(logger, "getting provider version ID", err)
	}

	// The overview has its own dedicated fetch and is returned as-is.
	if providerDetail.ProviderDataType == "overview" {
		return client.GetProviderOverviewDocs(httpClient, versionID, logger)
	}

	query := fmt.Sprintf("provider-docs?filter[provider-version]=%s&filter[category]=%s&filter[language]=hcl", versionID, providerDetail.ProviderDataType)
	docs, err := client.SendPaginatedRegistryCall(httpClient, query, logger)
	switch {
	case err != nil:
		return "", utils.LogAndReturnError(logger, "getting provider documentation", err)
	case len(docs) == 0:
		return "", fmt.Errorf("no %s documentation found for provider version %s", providerDetail.ProviderDataType, versionID)
	}

	var out strings.Builder
	fmt.Fprintf(&out, "Available Documentation (top matches) for %s in Terraform provider %s/%s version: %s\n\n", providerDetail.ProviderDataType, providerDetail.ProviderNamespace, providerDetail.ProviderName, providerDetail.ProviderVersion)
	out.WriteString("Each result includes:\n- providerDocID: tfprovider-compatible identifier\n- Title: Service or resource name\n- Category: Type of document\n- Description: Brief summary of the document\n")
	out.WriteString("For best results, select libraries based on the service_slug match and category of information requested.\n\n---\n\n")
	for _, doc := range docs {
		snippet, snippetErr := getContentSnippet(httpClient, doc.ID, logger)
		if snippetErr != nil {
			logger.Warnf("Error fetching content snippet for provider doc ID: %s: %v", doc.ID, snippetErr)
		}
		fmt.Fprintf(&out, "- providerDocID: %s\n- Title: %s\n- Category: %s\n- Description: %s\n---\n", doc.ID, doc.Attributes.Title, doc.Attributes.Category, snippet)
	}
	return out.String(), nil
}
// getContentSnippet fetches the v2 provider doc docID and returns a short
// description extracted from its markdown content.
//
// The description is the text after the first "description: |-" marker, up to
// the next "\n---" (or the end of the content). It is trimmed and has newlines
// replaced by spaces. Descriptions longer than 300 bytes are cut and suffixed
// with "..."; the cut is moved back to a rune boundary so a multi-byte UTF-8
// character is never split. An empty string is returned when the marker is
// absent.
func getContentSnippet(httpClient *http.Client, docID string, logger *log.Logger) (string, error) {
	const (
		descriptionMarker = "description: |-"
		maxSnippetBytes   = 300
	)

	// Go through the package-level sendRegistryCall (defaults to
	// client.SendRegistryCall) so tests that override it also cover snippets.
	docContent, err := sendRegistryCall(httpClient, "GET", fmt.Sprintf("provider-docs/%s", docID), logger, "v2")
	if err != nil {
		return "", utils.LogAndReturnError(logger, fmt.Sprintf("fetching provider-docs/%s within getContentSnippet", docID), err)
	}
	var docDescription client.ProviderResourceDetails
	if err := json.Unmarshal(docContent, &docDescription); err != nil {
		return "", utils.LogAndReturnError(logger, fmt.Sprintf("unmarshalling provider-docs/%s within getContentSnippet", docID), err)
	}
	content := docDescription.Data.Attributes.Content

	// Try to extract description from markdown content
	start := strings.Index(content, descriptionMarker)
	if start == -1 {
		return "", nil
	}
	body := content[start+len(descriptionMarker):]
	if end := strings.Index(body, "\n---"); end != -1 {
		body = body[:end]
	}
	desc := strings.ReplaceAll(strings.TrimSpace(body), "\n", " ")

	if len(desc) <= maxSnippetBytes {
		return desc, nil
	}
	// Step back over UTF-8 continuation bytes (10xxxxxx) so the cut falls on a
	// rune boundary; for ASCII text this is exactly the byte limit.
	cut := maxSnippetBytes
	for cut > 0 && desc[cut]&0xC0 == 0x80 {
		cut--
	}
	return desc[:cut] + "...", nil
}