|
1 | 1 | package server |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "crypto/tls" |
| 5 | + "crypto/x509" |
4 | 6 | "fmt" |
| 7 | + "io/ioutil" |
5 | 8 | "reflect" |
6 | 9 | "strconv" |
7 | 10 | "strings" |
@@ -157,3 +160,81 @@ func Test(t *testing.T) { |
157 | 160 | } |
158 | 161 | } |
159 | 162 | } |
| 163 | + |
| 164 | +func testServerTLS(t *testing.T) *tls.Config { |
| 165 | + cert, err := tls.LoadX509KeyPair("../testdata/server.crt", "../testdata/server.key") |
| 166 | + if err != nil { |
| 167 | + t.Fatal(err) |
| 168 | + } |
| 169 | + |
| 170 | + cp := x509.NewCertPool() |
| 171 | + rootca, err := ioutil.ReadFile("../testdata/client.crt") |
| 172 | + if err != nil { |
| 173 | + t.Fatal(err) |
| 174 | + } |
| 175 | + if !cp.AppendCertsFromPEM(rootca) { |
| 176 | + t.Fatal("client cert err") |
| 177 | + } |
| 178 | + return &tls.Config{ |
| 179 | + Certificates: []tls.Certificate{cert}, |
| 180 | + ClientAuth: tls.RequireAndVerifyClientCert, |
| 181 | + ServerName: "Server", |
| 182 | + ClientCAs: cp, |
| 183 | + } |
| 184 | +} |
| 185 | + |
| 186 | +func testClientTLS(t *testing.T) *tls.Config { |
| 187 | + cert, err := tls.LoadX509KeyPair("../testdata/client.crt", "../testdata/client.key") |
| 188 | + if err != nil { |
| 189 | + t.Fatal(err) |
| 190 | + } |
| 191 | + cp := x509.NewCertPool() |
| 192 | + rootca, err := ioutil.ReadFile("../testdata/server.crt") |
| 193 | + if err != nil { |
| 194 | + t.Fatal(err) |
| 195 | + } |
| 196 | + if !cp.AppendCertsFromPEM(rootca) { |
| 197 | + t.Fatal("server cert err") |
| 198 | + } |
| 199 | + return &tls.Config{ |
| 200 | + Certificates: []tls.Certificate{cert}, |
| 201 | + ServerName: "Server", |
| 202 | + RootCAs: cp, |
| 203 | + } |
| 204 | +} |
| 205 | + |
| 206 | +func TestTLS(t *testing.T) { |
| 207 | + s, err := NewServerTLS("127.0.0.1:0", testServerTLS(t)) |
| 208 | + if err != nil { |
| 209 | + t.Fatal(err) |
| 210 | + } |
| 211 | + defer s.Close() |
| 212 | + |
| 213 | + if have := s.Addr().Port; have <= 0 { |
| 214 | + t.Fatalf("have %v, want > 0", have) |
| 215 | + } |
| 216 | + |
| 217 | + s.Register("PING", func(c *Peer, cmd string, args []string) { |
| 218 | + c.WriteInline("PONG") |
| 219 | + }) |
| 220 | + |
| 221 | + cfg := testClientTLS(t) |
| 222 | + c, err := redis.Dial("tcp", |
| 223 | + s.Addr().String(), |
| 224 | + redis.DialTLSConfig(cfg), |
| 225 | + redis.DialUseTLS(true), |
| 226 | + ) |
| 227 | + if err != nil { |
| 228 | + t.Fatal(err) |
| 229 | + } |
| 230 | + |
| 231 | + { |
| 232 | + res, err := redis.String(c.Do("PING")) |
| 233 | + if err != nil { |
| 234 | + t.Fatal(err) |
| 235 | + } |
| 236 | + if have, want := res, "PONG"; have != want { |
| 237 | + t.Errorf("have: %s, want: %s", have, want) |
| 238 | + } |
| 239 | + } |
| 240 | +} |
0 commit comments