Skip to content

Commit

Permalink
Add support for certification revocation list files
Browse files Browse the repository at this point in the history
Signed-off-by: Hormoz Kheradmand <[email protected]>
  • Loading branch information
hkdsun committed Oct 13, 2021
1 parent 6b31715 commit 8e06dc7
Show file tree
Hide file tree
Showing 28 changed files with 538 additions and 65 deletions.
1 change: 1 addition & 0 deletions data/test/mysql_ldap_auth_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"LdapCert": "path/to/ldap-client-cert.pem",
"LdapKey": "path/to/ldap-client-key.pem",
"LdapCA": "path/to/ldap-client-ca.pem",
"LdapCRL": "path/to/ldap-client-crl.pem",
"User": "uid=vitessROuser,ou=users,ou=people,dc=example,dc=com",
"Password": "sUpErSeCuRe1",
"GroupQuery": "ou=groups,ou=people,dc=example,dc=com",
Expand Down
30 changes: 28 additions & 2 deletions go/cmd/vttlstest/vttlstest.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ var cmdMap map[string]cmdFunc
func init() {
cmdMap = map[string]cmdFunc{
"CreateCA": cmdCreateCA,
"CreateCRL": cmdCreateCRL,
"CreateSignedCert": cmdCreateSignedCert,
"RevokeCert": cmdRevokeCert,
}
}

Expand All @@ -65,6 +67,28 @@ func cmdCreateCA(subFlags *flag.FlagSet, args []string) {
tlstest.CreateCA(*root)
}

func cmdCreateCRL(subFlags *flag.FlagSet, args []string) {
subFlags.Parse(args)
if subFlags.NArg() != 1 {
log.Fatalf("CreateCRL command takes a single CA name as a parameter")
}

ca := subFlags.Arg(0)
tlstest.CreateCRL(*root, ca)
}

func cmdRevokeCert(subFlags *flag.FlagSet, args []string) {
parent := subFlags.String("parent", "ca", "Parent cert name to use. Use 'ca' for the toplevel CA.")

subFlags.Parse(args)
if subFlags.NArg() != 1 {
log.Fatalf("RevokeCert command takes a single name as a parameter")
}

name := subFlags.Arg(0)
tlstest.RevokeCertAndRegenerateCRL(*root, *parent, name)
}

func cmdCreateSignedCert(subFlags *flag.FlagSet, args []string) {
parent := subFlags.String("parent", "ca", "Parent cert name to use. Use 'ca' for the toplevel CA.")
serial := subFlags.String("serial", "01", "Serial number for the certificate to create. Should be different for two certificates with the same parent.")
Expand All @@ -74,11 +98,13 @@ func cmdCreateSignedCert(subFlags *flag.FlagSet, args []string) {
if subFlags.NArg() != 1 {
log.Fatalf("CreateSignedCert command takes a single name as a parameter")
}

name := subFlags.Arg(0)
if *commonName == "" {
*commonName = subFlags.Arg(0)
*commonName = name
}

tlstest.CreateSignedCert(*root, *parent, *serial, subFlags.Arg(0), *commonName)
tlstest.CreateSignedCert(*root, *parent, *serial, name, *commonName)
}

func main() {
Expand Down
4 changes: 4 additions & 0 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ func TestValidCert(t *testing.T) {
tlstest.CreateCA(root)
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", clientCertUsername)
tlstest.CreateCRL(root, tlstest.CA)

// Create the server with TLS config.
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
path.Join(root, "ca-crl.pem"),
"",
tls.VersionTLS12)
if err != nil {
Expand Down Expand Up @@ -136,12 +138,14 @@ func TestNoCert(t *testing.T) {
defer os.RemoveAll(root)
tlstest.CreateCA(root)
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
tlstest.CreateCRL(root, tlstest.CA)

// Create the server with TLS config.
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
path.Join(root, "ca-crl.pem"),
"",
tls.VersionTLS12)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion go/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error {
}

// Build the TLS config.
clientConfig, err := vttls.ClientConfig(params.EffectiveSslMode(), params.SslCert, params.SslKey, params.SslCa, serverName, tlsVersion)
clientConfig, err := vttls.ClientConfig(params.EffectiveSslMode(), params.SslCert, params.SslKey, params.SslCa, params.SslCrl, serverName, tlsVersion)
if err != nil {
return NewSQLError(CRSSLConnectionError, SSUnknownSQLState, "error loading client cert and ca: %v", err)
}
Expand Down
12 changes: 12 additions & 0 deletions go/mysql/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ func TestTLSClientDisabled(t *testing.T) {
path.Join(root, "server-key.pem"),
"",
"",
"",
tls.VersionTLS12)
require.NoError(t, err)
l.TLSConfig.Store(serverConfig)
Expand Down Expand Up @@ -260,6 +261,7 @@ func TestTLSClientPreferredDefault(t *testing.T) {
path.Join(root, "server-key.pem"),
"",
"",
"",
tls.VersionTLS12)
require.NoError(t, err)
l.TLSConfig.Store(serverConfig)
Expand Down Expand Up @@ -381,6 +383,7 @@ func TestTLSClientVerifyCA(t *testing.T) {
path.Join(root, "server-key.pem"),
"",
"",
"",
tls.VersionTLS12)
require.NoError(t, err)
l.TLSConfig.Store(serverConfig)
Expand Down Expand Up @@ -465,6 +468,7 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
path.Join(root, "server-key.pem"),
"",
"",
"",
tls.VersionTLS12)
require.NoError(t, err)
l.TLSConfig.Store(serverConfig)
Expand Down Expand Up @@ -511,4 +515,12 @@ func TestTLSClientVerifyIdentity(t *testing.T) {
if conn != nil {
conn.Close()
}

// Now revoke the server certificate and make sure we can't connect
tlstest.RevokeCertAndRegenerateCRL(root, tlstest.CA, "server")

params.SslCrl = path.Join(root, "ca-crl.pem")
_, err = Connect(context.Background(), params)
require.Error(t, err)
require.Contains(t, err.Error(), "Certificate revoked: CommonName=server.example.com")
}
1 change: 1 addition & 0 deletions go/mysql/conn_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ type ConnParams struct {
SslCa string `json:"ssl_ca"`
SslCaPath string `json:"ssl_ca_path"`
SslCert string `json:"ssl_cert"`
SslCrl string `json:"ssl_crl"`
SslKey string `json:"ssl_key"`
TLSMinVersion string `json:"tls_min_version"`
ServerName string `json:"server_name"`
Expand Down
1 change: 1 addition & 0 deletions go/mysql/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ func TestSSLConnection(t *testing.T) {
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
"",
"",
tls.VersionTLS12)
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
Expand Down
3 changes: 2 additions & 1 deletion go/mysql/ldapauthserver/auth_server_ldap.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ type ServerConfig struct {
LdapCert string
LdapKey string
LdapCA string
LdapCRL string
LdapTLSMinVersion string
}

Expand Down Expand Up @@ -250,7 +251,7 @@ func (lci *ClientImpl) Connect(network string, config *ServerConfig) error {
return err
}

tlsConfig, err := vttls.ClientConfig(vttls.VerifyIdentity, config.LdapCert, config.LdapKey, config.LdapCA, serverName, tlsVersion)
tlsConfig, err := vttls.ClientConfig(vttls.VerifyIdentity, config.LdapCert, config.LdapKey, config.LdapCA, config.LdapCRL, serverName, tlsVersion)
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions go/mysql/mysql_fuzzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ func FuzzTLSServer(data []byte) int {
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
"",
"",
tls.VersionTLS12)
if err != nil {
return -1
Expand Down
17 changes: 16 additions & 1 deletion go/mysql/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -833,6 +833,7 @@ func TestTLSServer(t *testing.T) {
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
"",
"",
tls.VersionTLS12)
require.NoError(t, err)
l.TLSConfig.Store(serverConfig)
Expand Down Expand Up @@ -924,12 +925,16 @@ func TestTLSRequired(t *testing.T) {
defer os.RemoveAll(root)
tlstest.CreateCA(root)
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
tlstest.CreateSignedCert(root, tlstest.CA, "03", "revoked-client", "Revoked Client Cert")
tlstest.RevokeCertAndRegenerateCRL(root, tlstest.CA, "revoked-client")

// Create the server with TLS config.
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
path.Join(root, "ca-crl.pem"),
"",
tls.VersionTLS12)
require.NoError(t, err)
Expand Down Expand Up @@ -966,7 +971,6 @@ func TestTLSRequired(t *testing.T) {
}

// setup conn params with TLS
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
params.SslMode = vttls.VerifyIdentity
params.SslCa = path.Join(root, "ca-cert.pem")
params.SslCert = path.Join(root, "client-cert.pem")
Expand All @@ -977,6 +981,16 @@ func TestTLSRequired(t *testing.T) {
if conn != nil {
conn.Close()
}

// setup conn params with TLS, but with a revoked client certificate
params.SslCert = path.Join(root, "revoked-client-cert.pem")
params.SslKey = path.Join(root, "revoked-client-key.pem")
conn, err = Connect(context.Background(), params)
require.NotNil(t, err)
require.Contains(t, err.Error(), "remote error: tls: bad certificate")
if conn != nil {
conn.Close()
}
}

func TestCachingSha2PasswordAuthWithTLS(t *testing.T) {
Expand Down Expand Up @@ -1013,6 +1027,7 @@ func TestCachingSha2PasswordAuthWithTLS(t *testing.T) {
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"),
"",
"",
tls.VersionTLS12)
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ func tabletConnExtraArgs(name string) []string {
}

func getVitessClient(addr string) (vtgateservicepb.VitessClient, error) {
opt, err := grpcclient.SecureDialOption(grpcCert, grpcKey, grpcCa, grpcName)
opt, err := grpcclient.SecureDialOption(grpcCert, grpcKey, grpcCa, "", grpcName)
if err != nil {
return nil, err
}
Expand Down
3 changes: 2 additions & 1 deletion go/vt/binlog/grpcbinlogplayer/player.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ var (
cert = flag.String("binlog_player_grpc_cert", "", "the cert to use to connect")
key = flag.String("binlog_player_grpc_key", "", "the key to use to connect")
ca = flag.String("binlog_player_grpc_ca", "", "the server ca to use to validate servers when connecting")
crl = flag.String("binlog_player_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
name = flag.String("binlog_player_grpc_server_name", "", "the server name to use to validate server certificate")
)

Expand All @@ -48,7 +49,7 @@ type client struct {
func (client *client) Dial(tablet *topodatapb.Tablet) error {
addr := netutil.JoinHostPort(tablet.Hostname, tablet.PortMap["grpc"])
var err error
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
if err != nil {
return err
}
Expand Down
4 changes: 2 additions & 2 deletions go/vt/grpcclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,15 @@ func interceptors() []grpc.DialOption {
// SecureDialOption returns the gRPC dial option to use for the
// given client connection. It is either using TLS, or Insecure if
// nothing is set.
func SecureDialOption(cert, key, ca, name string) (grpc.DialOption, error) {
func SecureDialOption(cert, key, ca, crl, name string) (grpc.DialOption, error) {
// No security options set, just return.
if (cert == "" || key == "") && ca == "" {
return grpc.WithInsecure(), nil
}

// Load the config. At this point we know
// we want a strict config with verify identity.
config, err := vttls.ClientConfig(vttls.VerifyIdentity, cert, key, ca, name, tls.VersionTLS12)
config, err := vttls.ClientConfig(vttls.VerifyIdentity, cert, key, ca, crl, name, tls.VersionTLS12)
if err != nil {
return nil, err
}
Expand Down
5 changes: 4 additions & 1 deletion go/vt/servenv/grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ var (
// GRPCCA is the CA to use if TLS is enabled
GRPCCA = flag.String("grpc_ca", "", "server CA to use for gRPC connections, requires TLS, and enforces client certificate check")

// GRPCCRL is the CRL (Certificate Revocation List) to use if TLS is enabled
GRPCCRL = flag.String("grpc_crl", "", "path to a certificate revocation list in PEM format, client certificates will be further verified against this file during TLS handshake")

GRPCEnableOptionalTLS = flag.Bool("grpc_enable_optional_tls", false, "enable optional TLS mode when a server accepts both TLS and plain-text connections on the same port")

// GRPCServerCA if specified will combine server cert and server CA
Expand Down Expand Up @@ -133,7 +136,7 @@ func createGRPCServer() {

var opts []grpc.ServerOption
if GRPCPort != nil && *GRPCCert != "" && *GRPCKey != "" {
config, err := vttls.ServerConfig(*GRPCCert, *GRPCKey, *GRPCCA, *GRPCServerCA, tls.VersionTLS12)
config, err := vttls.ServerConfig(*GRPCCert, *GRPCKey, *GRPCCA, *GRPCCRL, *GRPCServerCA, tls.VersionTLS12)
if err != nil {
log.Exitf("Failed to log gRPC cert/key/ca: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion go/vt/throttler/grpcthrottlerclient/grpcthrottlerclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ var (
cert = flag.String("throttler_client_grpc_cert", "", "the cert to use to connect")
key = flag.String("throttler_client_grpc_key", "", "the key to use to connect")
ca = flag.String("throttler_client_grpc_ca", "", "the server ca to use to validate servers when connecting")
crl = flag.String("throttler_client_grpc_crl", "", "the server crl to use to validate server certificates when connecting")
name = flag.String("throttler_client_grpc_server_name", "", "the server name to use to validate server certificate")
)

Expand All @@ -45,7 +46,7 @@ type client struct {
}

func factory(addr string) (throttlerclient.Client, error) {
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *name)
opt, err := grpcclient.SecureDialOption(*cert, *key, *ca, *crl, *name)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 8e06dc7

Please sign in to comment.