diff --git a/internal/server/api.go b/internal/server/api.go index d3dd02f2..f7497a20 100644 --- a/internal/server/api.go +++ b/internal/server/api.go @@ -121,6 +121,9 @@ func (a *API) ServeHTTP(w http.ResponseWriter, r *http.Request) { parts := strings.Split(r.URL.String(), "/") kind := parts[len(parts)-1] + + a.outputRequestData(kind, data) + actual, err := decodeWrapper(kind, data) if err != nil { a.pushError(err) @@ -134,7 +137,6 @@ func (a *API) ServeHTTP(w http.ResponseWriter, r *http.Request) { if kind == "increment_metric" { // Let's just output the metrics data and stop - a.outputRequestData(kind, actual) return } @@ -144,7 +146,6 @@ func (a *API) ServeHTTP(w http.ResponseWriter, r *http.Request) { } if !a.hasExpectations { - a.outputRequestData(kind, actual) return } @@ -178,12 +179,12 @@ func (a *API) assertExpectation(kind string, actual *model.UpdateWrapper) { } } -func (a *API) outputRequestData(kind string, actual *model.UpdateWrapper) { +func (a *API) outputRequestData(kind string, data []byte) { if a.writer != nil { // output the data received to stdout if err := json.NewEncoder(a.writer).Encode(map[string]any{ "type": kind, - "data": actual.Data, + "data": string(data), }); err != nil { // Fail so the user knows stdout is not working log.Panicln("Failed to write to stdout: ", err) diff --git a/internal/server/api_test.go b/internal/server/api_test.go index 9f4a3316..8b9d0071 100644 --- a/internal/server/api_test.go +++ b/internal/server/api_test.go @@ -1,8 +1,16 @@ package server import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "github.com/dependabot/cli/internal/model" "net/http" "net/http/httptest" + "strings" "testing" ) @@ -28,3 +36,63 @@ func TestAPI_ServeHTTP(t *testing.T) { } }) } + +func TestAPI_CreatePullRequest_ReplacesBinaryWithHash(t *testing.T) { + var stdout bytes.Buffer + + api := NewAPI(nil, &stdout) + defer api.Stop() + + content := base64.StdEncoding.EncodeToString([]byte("Hello, world!")) + hash := sha256.Sum256([]byte(content)) + expectedHashedContent := hex.EncodeToString(hash[:]) + + // Construct the request body for create_pull_request + createPullRequest := model.CreatePullRequest{ + UpdatedDependencyFiles: []model.DependencyFile{ + { + Content: content, + ContentEncoding: "base64", + }, + }, + } + var body bytes.Buffer + if err := json.NewEncoder(&body).Encode(model.UpdateWrapper{Data: createPullRequest}); err != nil { + t.Fatalf("failed to encode request body: %v", err) + } + + url := "http://127.0.0.1:" + // use the API's port + fmt.Sprintf("%d/create_pull_request", api.Port()) + req, err := http.NewRequest("POST", url, &body) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + t.Fatalf("failed to send request: %v", err) + } + defer resp.Body.Close() + + if len(api.Errors) > 0 { + t.Fatalf("expected no errors, got %d errors: %v", len(api.Errors), api.Errors) + } + + // The API should have replaced the content with a SHA hash in a.Actual.Output + if len(api.Actual.Output) != 1 { + t.Fatalf("expected 1 output, got %d", len(api.Actual.Output)) + } + if api.Actual.Output[0].Type != "create_pull_request" { + t.Fatalf("expected output type 'create_pull_request', got '%s'", api.Actual.Output[0].Type) + } + if api.Actual.Output[0].Expect.Data.(model.CreatePullRequest).UpdatedDependencyFiles[0].Content != expectedHashedContent { + t.Errorf("expected content to be 'hello', got '%s'", api.Actual.Output[0].Expect.Data.(model.CreatePullRequest).UpdatedDependencyFiles[0].Content) + } + + // stdout should contain the original content so folks can create PRs + if !strings.Contains(stdout.String(), content) { + t.Errorf("expected stdout to contain the original content, got '%s'", stdout.String()) + } +}