// Source: mosquitto-go-auth/backends/redis.go
// (380 lines, 9.1 KiB, Go)

package backends
import (
"context"
"errors"
"fmt"
"strconv"
"strings"
"time"
goredis "github.com/go-redis/redis/v8"
"github.com/iegomez/mosquitto-go-auth/hashing"
log "github.com/sirupsen/logrus"
)
// RedisClient is the set of Redis operations the backend depends on.
// NewRedis stores either a go-redis cluster client or a SingleRedisClient
// behind this interface.
type RedisClient interface {
	// Connection lifecycle.
	Ping(ctx context.Context) *goredis.StatusCmd
	Close() error
	// ReloadState is invoked by the backend after a MOVED/ASK error, before
	// the failed command is retried.
	ReloadState(ctx context.Context) error

	// Read commands used for authentication and ACL lookups.
	Get(ctx context.Context, key string) *goredis.StringCmd
	SMembers(ctx context.Context, key string) *goredis.StringSliceCmd

	// Write commands. NOTE(review): not called anywhere in the visible part of
	// this file; presumably needed by tests or other callers — confirm.
	Set(ctx context.Context, key string, value interface{}, expiration time.Duration) *goredis.StatusCmd
	SAdd(ctx context.Context, key string, members ...interface{}) *goredis.IntCmd
	Expire(ctx context.Context, key string, expiration time.Duration) *goredis.BoolCmd
	FlushDB(ctx context.Context) *goredis.StatusCmd
}
// SingleClientError is the error SingleRedisClient.ReloadState always returns:
// reloading state is not supported by the single (non-cluster) client.
var SingleClientError = errors.New("unsupported reload state operation for Redis single client")

// SingleRedisClient embeds a go-redis single-node client and adds the
// ReloadState method required by RedisClient.
type SingleRedisClient struct {
	*goredis.Client
}

// ReloadState is a stub that unconditionally fails with SingleClientError.
func (SingleRedisClient) ReloadState(context.Context) error {
	return SingleClientError
}
// Redis is the Redis-backed auth backend. It holds the connection settings
// read by NewRedis plus the live client used for user, superuser and ACL
// lookups.
type Redis struct {
// Host and Port form the "host:port" address used in single-node mode.
// NewRedis defaults them to "localhost" and "6379".
Host string
Port string
// Password is passed to both the single-node and the cluster client.
Password string
// SaltEncoding defaults to "base64" in NewRedis.
// NOTE(review): it is not read anywhere in the visible code — confirm it is still needed.
SaltEncoding string
// DB is the database index used in single-node mode (default 1).
DB int32
// conn is the client every query goes through.
conn RedisClient
// disableSuperuser makes GetSuperuser return false without querying Redis.
disableSuperuser bool
// ctx is the context handed to every Redis command; NewRedis sets it to context.Background().
ctx context.Context
// hasher compares a plaintext password against the hash stored in Redis.
hasher hashing.HashComparer
}
// NewRedis builds a Redis backend from authOpts and blocks until the server
// answers a PING, retrying every two seconds.
//
// Recognised options:
//   - redis_host, redis_port: single-node address (defaults "localhost", "6379").
//   - redis_password: password for single-node and cluster clients.
//   - redis_db: single-node database index (default 1). An unparsable value
//     is logged and the default is kept.
//   - redis_disable_superuser: "true" disables superuser lookups.
//   - redis_mode: "cluster" selects a cluster client, which then requires
//     redis_cluster_addresses, a comma-separated list of addresses.
//
// It sets the global logrus level to logLevel. An error is returned only when
// cluster mode is requested without any usable address.
func NewRedis(authOpts map[string]string, logLevel log.Level, hasher hashing.HashComparer) (Redis, error) {
	log.SetLevel(logLevel)

	var redis = Redis{
		Host:         "localhost",
		Port:         "6379",
		DB:           1,
		SaltEncoding: "base64",
		ctx:          context.Background(),
		hasher:       hasher,
	}

	if authOpts["redis_disable_superuser"] == "true" {
		redis.disableSuperuser = true
	}

	if redisHost, ok := authOpts["redis_host"]; ok {
		redis.Host = redisHost
	}

	if redisPort, ok := authOpts["redis_port"]; ok {
		redis.Port = redisPort
	}

	if redisPassword, ok := authOpts["redis_password"]; ok {
		redis.Password = redisPassword
	}

	if redisDB, ok := authOpts["redis_db"]; ok {
		db, err := strconv.ParseInt(redisDB, 10, 32)
		if err != nil {
			// Keep the default instead of failing, but make the
			// misconfiguration visible rather than silently ignoring it.
			log.Warnf("redis backend: invalid redis_db %q, using default %d: %s", redisDB, redis.DB, err)
		} else {
			redis.DB = int32(db)
		}
	}

	if authOpts["redis_mode"] == "cluster" {
		// Trim each address and drop blank entries (e.g. from a trailing
		// comma) so that an empty address is never handed to the client.
		var addresses []string
		for _, address := range strings.Split(authOpts["redis_cluster_addresses"], ",") {
			if address = strings.TrimSpace(address); address != "" {
				addresses = append(addresses, address)
			}
		}
		if len(addresses) == 0 {
			return redis, fmt.Errorf("redis backend: missing Redis Cluster addresses")
		}

		clusterClient := goredis.NewClusterClient(
			&goredis.ClusterOptions{
				Addrs:    addresses,
				Password: redis.Password,
			})
		redis.conn = clusterClient
	} else {
		addr := fmt.Sprintf("%s:%s", redis.Host, redis.Port)

		redisClient := goredis.NewClient(&goredis.Options{
			Addr:     addr,
			Password: redis.Password,
			DB:       int(redis.DB),
		})
		redis.conn = &SingleRedisClient{redisClient}
	}

	// Wait for the server to become reachable before reporting the backend
	// as ready. This loop has no upper bound on retries.
	for {
		_, err := redis.conn.Ping(redis.ctx).Result()
		if err == nil {
			break
		}
		log.Errorf("ping redis error, will retry in 2s: %s", err)
		time.Sleep(2 * time.Second)
	}

	return redis, nil
}
// isMovedError reports whether err is a Redis Cluster redirection, i.e. its
// message starts with "MOVED " or "ASK ". A nil error is never a redirection.
func isMovedError(err error) bool {
	// Guard against a nil error: calling Error() on it would panic.
	if err == nil {
		return false
	}
	msg := err.Error()
	return strings.HasPrefix(msg, "MOVED ") || strings.HasPrefix(msg, "ASK ")
}
// GetUser reports whether the password stored under username matches the
// given password. The third argument is ignored. If the lookup fails with a
// cluster redirection, the client state is reloaded and the lookup is retried
// once.
func (o Redis) GetUser(username, password, _ string) (bool, error) {
	valid, err := o.getUser(username, password)

	switch {
	case err == nil:
		return valid, nil
	case isMovedError(err):
		// Redis Cluster moved the key: refresh the state, then retry once.
		if reloadErr := o.conn.ReloadState(o.ctx); reloadErr != nil {
			log.Debugf("redis reload state error: %s", reloadErr)
			return false, reloadErr
		}
		valid, err = o.getUser(username, password)
	}

	if err != nil {
		log.Debugf("redis get user error: %s", err)
	}
	return valid, err
}
// getUser fetches the hash stored under the key username and compares it with
// password using the configured hasher. A missing key yields (false, nil);
// any other Redis error is returned as is.
func (o Redis) getUser(username, password string) (bool, error) {
	storedHash, err := o.conn.Get(o.ctx, username).Result()
	switch {
	case err == goredis.Nil:
		// No such user.
		return false, nil
	case err != nil:
		return false, err
	}

	return o.hasher.Compare(password, storedHash), nil
}
// GetSuperuser reports whether the key "<username>:su" holds the value
// "true". It returns false without querying Redis when superuser checks are
// disabled. If the lookup fails with a cluster redirection, the client state
// is reloaded and the lookup is retried once.
func (o Redis) GetSuperuser(username string) (bool, error) {
	if o.disableSuperuser {
		return false, nil
	}

	isSuper, err := o.getSuperuser(username)

	switch {
	case err == nil:
		return isSuper, nil
	case isMovedError(err):
		// Redis Cluster moved the key: refresh the state, then retry once.
		if reloadErr := o.conn.ReloadState(o.ctx); reloadErr != nil {
			log.Debugf("redis reload state error: %s", reloadErr)
			return false, reloadErr
		}
		isSuper, err = o.getSuperuser(username)
	}

	if err != nil {
		log.Debugf("redis get superuser error: %s", err)
	}
	return isSuper, err
}
// getSuperuser reads the key "<username>:su" and reports whether its value is
// exactly "true". A missing key yields (false, nil); any other Redis error is
// returned as is.
func (o Redis) getSuperuser(username string) (bool, error) {
	key := fmt.Sprintf("%s:su", username)

	value, err := o.conn.Get(o.ctx, key).Result()
	if err != nil {
		if err == goredis.Nil {
			// Key not present: not a superuser.
			return false, nil
		}
		return false, err
	}

	return value == "true", nil
}
// CheckAcl reports whether username/clientid is granted access acc on topic.
// If the lookup fails with a cluster redirection, the client state is reloaded
// and the lookup is retried once.
func (o Redis) CheckAcl(username, topic, clientid string, acc int32) (bool, error) {
	allowed, err := o.checkAcl(username, topic, clientid, acc)

	switch {
	case err == nil:
		return allowed, nil
	case isMovedError(err):
		// Redis Cluster moved the key: refresh the state, then retry once.
		if reloadErr := o.conn.ReloadState(o.ctx); reloadErr != nil {
			log.Debugf("redis reload state error: %s", reloadErr)
			return false, reloadErr
		}
		allowed, err = o.checkAcl(username, topic, clientid, acc)
	}

	if err != nil {
		log.Debugf("redis check acl error: %s", err)
	}
	return allowed, err
}
// checkAcl loads the ACL sets that apply to access type acc and reports
// whether any of them matches topic.
//
// User ACLs are read from "<username>:sacls" (subscribe), "<username>:racls"
// plus "<username>:rwacls" (read), or "<username>:wacls" plus
// "<username>:rwacls" (write). Common ACLs are read from the matching
// "common:*" sets and have "%c" and "%u" replaced by clientid and username
// before matching. Any other acc value loads nothing and yields (false, nil).
//
// A goredis.Nil reply for any set yields (false, nil); any other Redis error
// is returned as is.
func (o Redis) checkAcl(username, topic, clientid string, acc int32) (bool, error) {
	// Pick the Redis sets to read depending on whether the client is
	// subscribing, reading or publishing.
	var userKeys, commonKeys []string
	switch acc {
	case MOSQ_ACL_SUBSCRIBE:
		userKeys = []string{fmt.Sprintf("%s:sacls", username)}
		commonKeys = []string{"common:sacls"}
	case MOSQ_ACL_READ:
		userKeys = []string{fmt.Sprintf("%s:racls", username), fmt.Sprintf("%s:rwacls", username)}
		commonKeys = []string{"common:racls", "common:rwacls"}
	case MOSQ_ACL_WRITE:
		userKeys = []string{fmt.Sprintf("%s:wacls", username), fmt.Sprintf("%s:rwacls", username)}
		commonKeys = []string{"common:wacls", "common:rwacls"}
	}

	// loadAcls reads the given sets in order and concatenates their members.
	// The boolean is false when Redis answered goredis.Nil for one of them.
	// The result starts empty so that no blank placeholder entries end up in
	// the list that is later matched against the topic.
	loadAcls := func(keys []string) ([]string, bool, error) {
		var loaded []string
		for _, key := range keys {
			members, err := o.conn.SMembers(o.ctx, key).Result()
			if err == goredis.Nil {
				return nil, false, nil
			} else if err != nil {
				return nil, false, err
			}
			loaded = append(loaded, members...)
		}
		return loaded, true, nil
	}

	// User specific acls.
	acls, found, err := loadAcls(userKeys)
	if err != nil || !found {
		return false, err
	}

	// Common acls.
	commonAcls, found, err := loadAcls(commonKeys)
	if err != nil || !found {
		return false, err
	}

	// Now loop through acls looking for a match.
	for _, acl := range acls {
		if TopicsMatch(acl, topic) {
			return true, nil
		}
	}

	for _, acl := range commonAcls {
		aclTopic := strings.Replace(acl, "%c", clientid, -1)
		aclTopic = strings.Replace(aclTopic, "%u", username, -1)
		if TopicsMatch(aclTopic, topic) {
			return true, nil
		}
	}

	return false, nil
}
// GetName returns the fixed identifier of this backend, "Redis".
func (Redis) GetName() string {
	const backendName = "Redis"
	return backendName
}
// Halt closes the Redis client, if there is one. A failure to close is
// logged rather than returned.
func (o Redis) Halt() {
	if o.conn == nil {
		return
	}
	if err := o.conn.Close(); err != nil {
		log.Errorf("Redis cleanup error: %s", err)
	}
}