mosquitto-go-auth/backends/grpc.go

235 lines
5.6 KiB
Go

package backends
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"strconv"
"time"
"google.golang.org/grpc/credentials/insecure"
"github.com/golang/protobuf/ptypes/empty"
grpc_logrus "github.com/grpc-ecosystem/go-grpc-middleware/logging/logrus"
gs "github.com/iegomez/mosquitto-go-auth/grpc"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
// GRPC holds a client for the service and implements the Backend interface.
//
// Values are built by NewGRPC, which fills in the dial settings and then
// dials the remote auth service via initClient.
type GRPC struct {
// client is the generated AuthService stub used by every auth RPC.
client gs.AuthServiceClient
// conn is the underlying connection the client was built on; Halt closes it.
conn *grpc.ClientConn
// disableSuperuser makes GetSuperuser answer false without issuing an RPC.
disableSuperuser bool
// dialOptions are the options produced by setup (logging interceptor,
// optional WithBlock, transport credentials).
dialOptions []grpc.DialOption
// hostname is the "host:port" dial target.
hostname string
// timeout is the dial timeout in milliseconds used by initClient.
timeout int
// name caches the backend name returned by GetName.
name string
}
// defaultGRPCTimeoutMs is the dial timeout, in milliseconds, used when
// grpc_dial_timeout_ms is absent or cannot be parsed as an integer.
const defaultGRPCTimeoutMs = 500
// NewGRPC tries to connect to the gRPC service at the given host.
//
// Recognised authOpts keys:
//   - grpc_host, grpc_port (required): dial target, joined as "host:port".
//   - grpc_disable_superuser: "true" makes GetSuperuser always return false.
//   - grpc_dial_timeout_ms: dial timeout in milliseconds (default
//     defaultGRPCTimeoutMs); unparsable or non-positive values are ignored
//     with a warning.
//   - grpc_ca_cert, grpc_tls_cert, grpc_tls_key: TLS material handed to setup.
//   - grpc_fail_on_dial_error: "true" adds grpc.WithBlock to the dial options,
//     so a failed dial within the timeout is reported as an error here.
//
// logLevel is currently not used by this backend.
//
// It returns an error when host or port is missing, when the TLS material
// cannot be loaded, or when dialing fails.
func NewGRPC(authOpts map[string]string, logLevel log.Level) (*GRPC, error) {
	if authOpts["grpc_host"] == "" || authOpts["grpc_port"] == "" {
		return nil, errors.New("grpc must have a host and port")
	}

	g := &GRPC{
		timeout:          defaultGRPCTimeoutMs,
		disableSuperuser: authOpts["grpc_disable_superuser"] == "true",
	}

	if timeout, ok := authOpts["grpc_dial_timeout_ms"]; ok {
		timeoutMs, err := strconv.Atoi(timeout)
		switch {
		case err != nil:
			log.Warnf("invalid grpc dial timeout value: %s", err)
		case timeoutMs <= 0:
			// A zero or negative deadline expires immediately, which makes a
			// blocking dial (grpc_fail_on_dial_error) fail every time.
			log.Warnf("invalid grpc dial timeout value: %d must be positive, using default of %d", timeoutMs, defaultGRPCTimeoutMs)
		default:
			g.timeout = timeoutMs
		}
	}

	caCert := authOpts["grpc_ca_cert"]
	tlsCert := authOpts["grpc_tls_cert"]
	tlsKey := authOpts["grpc_tls_key"]
	addr := fmt.Sprintf("%s:%s", authOpts["grpc_host"], authOpts["grpc_port"])
	withBlock := authOpts["grpc_fail_on_dial_error"] == "true"

	options, err := setup(addr, caCert, tlsCert, tlsKey, withBlock)
	if err != nil {
		return nil, err
	}

	g.dialOptions = options
	g.hostname = addr

	if err := g.initClient(); err != nil {
		return nil, err
	}

	return g, nil
}
// GetUser forwards a login check (username, password, clientid) to the
// remote auth service and reports its Ok verdict.
//
// RPC failures are logged and returned together with false.
func (o *GRPC) GetUser(username, password, clientid string) (bool, error) {
	resp, err := o.client.GetUser(context.Background(), &gs.GetUserRequest{
		Username: username,
		Password: password,
		Clientid: clientid,
	})
	if err != nil {
		log.Errorf("grpc get user error: %s", err)
		return false, err
	}

	return resp.Ok, nil
}
// GetSuperuser asks the remote auth service whether username is a superuser.
//
// When superuser checks are disabled (grpc_disable_superuser) it answers
// false without contacting the service. RPC failures are logged and
// returned together with false.
func (o *GRPC) GetSuperuser(username string) (bool, error) {
	if o.disableSuperuser {
		return false, nil
	}

	request := &gs.GetSuperuserRequest{Username: username}
	resp, err := o.client.GetSuperuser(context.Background(), request)
	if err != nil {
		log.Errorf("grpc get superuser error: %s", err)
		return false, err
	}

	return resp.Ok, nil
}
// CheckAcl asks the remote auth service whether the given user/client may
// access topic with access level acc, and reports its Ok verdict.
// NOTE(review): the meaning of acc values is defined by the caller/service,
// not here.
//
// RPC failures are logged and returned together with false.
func (o *GRPC) CheckAcl(username, topic, clientid string, acc int32) (bool, error) {
	resp, err := o.client.CheckAcl(context.Background(), &gs.CheckAclRequest{
		Username: username,
		Topic:    topic,
		Clientid: clientid,
		Acc:      acc,
	})
	if err != nil {
		log.Errorf("grpc check acl error: %s", err)
		return false, err
	}

	return resp.Ok, nil
}
// GetName returns the backend's name as reported by the remote service.
//
// The first successful lookup is cached in o.name. If the GetName RPC fails,
// the fallback "gRPC" is cached and returned instead. An empty name returned
// by the service is not cached, so the service is queried again next time.
func (o *GRPC) GetName() string {
	if o.name != "" {
		return o.name
	}

	if resp, err := o.client.GetName(context.Background(), &empty.Empty{}); err == nil {
		o.name = resp.Name
	} else {
		o.name = "gRPC"
	}

	return o.name
}
// Halt signals the gRPC backend that mosquitto is halting, then closes the
// underlying connection.
//
// Errors from the Halt RPC and from closing the connection are logged rather
// than returned. The connection reference is cleared after closing so that a
// repeated Halt does not attempt to close it a second time.
func (o *GRPC) Halt() {
	if o.client != nil {
		if _, err := o.client.Halt(context.Background(), &empty.Empty{}); err != nil {
			log.Errorf("grpc halt: %s", err)
		}
	}

	if o.conn != nil {
		if err := o.conn.Close(); err != nil {
			log.Errorf("grpc close connection: %s", err)
		}
		o.conn = nil
	}
}
// setup builds the dial options used to connect to the auth service at
// hostname.
//
// Every connection gets a logrus unary client interceptor; withBlock adds
// grpc.WithBlock. When caCert is empty the connection uses insecure
// (plaintext) transport credentials, and any client certificate/key given is
// ignored with a warning. Otherwise the PEM file at caCert becomes the trust
// root, and if both tlsCert and tlsKey are set that key pair is presented as
// the client certificate (mutual TLS).
//
// It returns an error if the CA file cannot be read or yields no certificate,
// or if the client key pair cannot be loaded.
func setup(hostname string, caCert string, tlsCert string, tlsKey string, withBlock bool) ([]grpc.DialOption, error) {
	logrusEntry := log.NewEntry(log.StandardLogger())
	logrusOpts := []grpc_logrus.Option{
		grpc_logrus.WithLevels(grpc_logrus.DefaultCodeToLevel),
	}

	nsOpts := []grpc.DialOption{
		grpc.WithUnaryInterceptor(
			grpc_logrus.UnaryClientInterceptor(logrusEntry, logrusOpts...),
		),
	}

	if withBlock {
		nsOpts = append(nsOpts, grpc.WithBlock())
	}

	if len(caCert) == 0 {
		nsOpts = append(nsOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
		log.WithField("server", hostname).Warning("creating insecure grpc client")
		// Client certificates are only used together with a CA, so surface
		// configurations where they would otherwise be dropped silently.
		if len(tlsCert) != 0 || len(tlsKey) != 0 {
			log.Warn("gRPC backend warning: client certificate (grpc_tls_cert) and key (grpc_tls_key) are ignored because no ca certificate (grpc_ca_cert) was provided")
		}
		return nsOpts, nil
	}

	log.WithField("server", hostname).Info("creating grpc client")

	caCertBytes, err := ioutil.ReadFile(caCert)
	if err != nil {
		return nil, errors.Wrapf(err, "could not load grpc ca certificate (grpc_ca_cert) from file (%s)", caCert)
	}

	caCertPool := x509.NewCertPool()
	if !caCertPool.AppendCertsFromPEM(caCertBytes) {
		return nil, errors.New("append ca cert to pool error. Maybe the ca file (grpc_ca_cert) does not contain a valid x509 certificate")
	}

	tlsConfig := &tls.Config{
		RootCAs: caCertPool,
	}

	switch {
	case len(tlsCert) != 0 && len(tlsKey) != 0:
		cert, err := tls.LoadX509KeyPair(tlsCert, tlsKey)
		if err != nil {
			return nil, errors.Wrap(err, "load x509 keypair error")
		}
		tlsConfig.Certificates = []tls.Certificate{cert}
	case len(tlsCert) != 0 || len(tlsKey) != 0:
		log.Warn("gRPC backend warning: mutual TLS was disabled due to missing client certificate (grpc_tls_cert) or client key (grpc_tls_key)")
	}

	nsOpts = append(nsOpts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
	return nsOpts, nil
}
// initClient dials o.hostname with o.dialOptions, bounded by a context that
// expires after o.timeout milliseconds, and on success stores both the
// connection and an AuthService client built on top of it.
func (o *GRPC) initClient() error {
	dialTimeout := time.Duration(o.timeout) * time.Millisecond
	ctx, cancel := context.WithTimeout(context.Background(), dialTimeout)
	defer cancel()

	conn, err := grpc.DialContext(ctx, o.hostname, o.dialOptions...)
	if err != nil {
		return err
	}

	o.conn, o.client = conn, gs.NewAuthServiceClient(conn)
	return nil
}