mosquitto-go-auth/go-auth.go

405 lines
9.9 KiB
Go

package main
import "C"
import (
"context"
"os"
"strconv"
"strings"
"time"
bes "github.com/iegomez/mosquitto-go-auth/backends"
"github.com/iegomez/mosquitto-go-auth/cache"
"github.com/iegomez/mosquitto-go-auth/hashing"
log "github.com/sirupsen/logrus"
)
// AuthPlugin bundles the plugin-wide configuration and runtime state shared by
// the exported mosquitto entry points.
type AuthPlugin struct {
	// backends is created by AuthPluginInit and consulted for credential and ACL checks.
	backends *bes.Backends
	// useCache controls whether the cache is consulted before, and updated after,
	// a backend check.
	useCache bool
	// logLevel is parsed from the log_level option (default info) and handed to the backends.
	logLevel log.Level
	// logDest and logFile are not read anywhere in this file.
	// NOTE(review): possibly unused — confirm before relying on them.
	logDest string
	logFile string
	// ctx is the context passed to every cache call.
	ctx context.Context
	// cache is the active cache store; it is set to nil when the cache fails to connect.
	cache cache.Store
	// hasher is not referenced in this file. NOTE(review): confirm whether it is used elsewhere.
	hasher hashing.HashComparer
	// retryCount is the number of extra attempts made when a check returns an error.
	retryCount int
}
// Result codes returned to mosquitto by the exported check functions.
const (
	AuthRejected = iota // the check completed and access is denied
	AuthGranted         // the check completed and access is allowed
	AuthError           // the check could not be completed (error after all attempts)
)
var (
	// authOpts holds the raw key/value options passed in by mosquitto.
	authOpts map[string]string
	// authPlugin is the single, process-wide plugin configuration and state.
	authPlugin AuthPlugin
)
// AuthPluginInit is exported to mosquitto and called when the plugin is
// loaded. keys and values are parallel arrays holding authOptsNum plugin
// options; version is forwarded to the backends.
//
// It configures logging (log_level, log_dest, log_file), the retry count
// (retry_count), the backends, and — when the cache option is "true" — the
// cache. A backend initialization error is reported with log.Fatalf, which
// terminates the process.
//
//export AuthPluginInit
func AuthPluginInit(keys []*C.char, values []*C.char, authOptsNum int, version *C.char) {
	log.SetFormatter(&log.TextFormatter{
		FullTimestamp: true,
	})

	// Initialize auth plugin struct with default and given values.
	authPlugin = AuthPlugin{
		logLevel: log.InfoLevel,
		ctx:      context.Background(),
	}

	authOpts = make(map[string]string)
	for i := 0; i < authOptsNum; i++ {
		authOpts[C.GoString(keys[i])] = C.GoString(values[i])
	}

	if retryCount, ok := authOpts["retry_count"]; ok {
		retry, err := strconv.Atoi(strings.TrimSpace(retryCount))
		switch {
		case err != nil:
			log.Warningf("couldn't parse retryCount (err: %s), defaulting to 0", err)
		case retry < 0:
			// A negative retry count is meaningless; keep the default instead
			// of storing a value that could suppress every check attempt.
			log.Warningf("retryCount can't be negative (got %d), defaulting to 0", retry)
		default:
			authPlugin.retryCount = retry
		}
	}

	// Check if log level is given. Set level if any valid option is given.
	if logLevel, ok := authOpts["log_level"]; ok {
		switch strings.Replace(logLevel, " ", "", -1) {
		case "debug":
			authPlugin.logLevel = log.DebugLevel
		case "info":
			authPlugin.logLevel = log.InfoLevel
		case "warn":
			authPlugin.logLevel = log.WarnLevel
		case "error":
			authPlugin.logLevel = log.ErrorLevel
		case "fatal":
			authPlugin.logLevel = log.FatalLevel
		case "panic":
			authPlugin.logLevel = log.PanicLevel
		default:
			log.Info("log_level unknown, using default info level")
		}
	}
	// Apply the level to the global logger so this file's own log calls (such
	// as the Debugf calls in the check functions) honor it. Setting it again
	// is harmless if the backends also do so.
	log.SetLevel(authPlugin.logLevel)

	if logDest, ok := authOpts["log_dest"]; ok {
		switch logDest {
		case "stdout":
			log.SetOutput(os.Stdout)
		case "file":
			logFile, ok := authOpts["log_file"]
			if !ok {
				log.Error("log_dest is file but log_file is missing, using default stderr")
				break
			}
			file, err := os.OpenFile(logFile, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
			if err != nil {
				log.Errorf("failed to log to file, using default stderr: %s", err)
				break
			}
			log.SetOutput(file)
		default:
			log.Info("log_dest unknown, using default stderr")
		}
	}

	var err error
	authPlugin.backends, err = bes.Initialize(authOpts, authPlugin.logLevel, C.GoString(version))
	if err != nil {
		log.Fatalf("error initializing backends: %s", err)
	}

	// The cache option is compared after stripping all spaces. The local is
	// not named "cache" so it does not shadow the imported cache package.
	if cacheOpt, ok := authOpts["cache"]; ok && strings.Replace(cacheOpt, " ", "", -1) == "true" {
		log.Info("redisCache activated")
		authPlugin.useCache = true
		// setCache may turn useCache back off if the cache can't be set up.
		setCache(authOpts)
	} else {
		log.Info("No cache set.")
		authPlugin.useCache = false
	}
}
// setCache builds authPlugin.cache from the cache-related options and connects it.
//
// Options read:
//   - auth_cache_seconds / acl_cache_seconds: record TTLs in seconds (default 30).
//   - auth_jitter_seconds / acl_jitter_seconds: jitter in seconds (default 0),
//     clamped so it never exceeds the matching TTL.
//   - cache_reset: passed to Connect when exactly "true".
//   - cache_refresh: passed to the store constructor when exactly "true".
//   - cache_type: "redis" selects a Redis store (a cluster when cache_mode is
//     "true", a single instance otherwise); any other value selects the Go store.
//
// If the cache cannot be set up (no usable cluster addresses, or Connect
// reporting failure), caching is disabled by clearing authPlugin.useCache.
func setCache(authOpts map[string]string) {
	authCacheSeconds := parseSecondsOption(authOpts, "auth_cache_seconds", "authCacheSeconds", 30)
	authJitterSeconds := parseSecondsOption(authOpts, "auth_jitter_seconds", "authJitterSeconds", 0)
	if authJitterSeconds > authCacheSeconds {
		authJitterSeconds = authCacheSeconds
		log.Warningf("authJitterSeconds is larger than authCacheSeconds, defaulting to %d", authJitterSeconds)
	}

	aclCacheSeconds := parseSecondsOption(authOpts, "acl_cache_seconds", "aclCacheSeconds", 30)
	aclJitterSeconds := parseSecondsOption(authOpts, "acl_jitter_seconds", "aclJitterSeconds", 0)
	if aclJitterSeconds > aclCacheSeconds {
		aclJitterSeconds = aclCacheSeconds
		log.Warningf("aclJitterSeconds is larger than aclCacheSeconds, defaulting to %d", aclJitterSeconds)
	}

	authExpiration := time.Duration(authCacheSeconds) * time.Second
	aclExpiration := time.Duration(aclCacheSeconds) * time.Second
	authJitter := time.Duration(authJitterSeconds) * time.Second
	aclJitter := time.Duration(aclJitterSeconds) * time.Second

	// Missing keys read as "", so these are false unless explicitly "true".
	reset := authOpts["cache_reset"] == "true"
	refreshExpiration := authOpts["cache_refresh"] == "true"

	switch authOpts["cache_type"] {
	case "redis":
		password := authOpts["cache_password"]

		if authOpts["cache_mode"] == "true" {
			// Trim spaces from the given addresses and drop empty entries
			// (e.g. produced by a trailing or doubled comma).
			var addresses []string
			for _, addr := range strings.Split(authOpts["redis_cluster_addresses"], ",") {
				if addr = strings.TrimSpace(addr); addr != "" {
					addresses = append(addresses, addr)
				}
			}
			if len(addresses) == 0 {
				log.Errorln("cache Redis cluster addresses missing, defaulting to no cache.")
				authPlugin.useCache = false
				return
			}
			authPlugin.cache = cache.NewRedisClusterStore(
				password,
				addresses,
				authExpiration,
				aclExpiration,
				authJitter,
				aclJitter,
				refreshExpiration,
			)
		} else {
			host := "localhost"
			if cacheHost, ok := authOpts["cache_host"]; ok {
				host = cacheHost
			}
			port := "6379"
			if cachePort, ok := authOpts["cache_port"]; ok {
				port = cachePort
			}
			db := 3
			if cacheDB, ok := authOpts["cache_db"]; ok {
				parsedDB, err := strconv.ParseInt(cacheDB, 10, 32)
				if err != nil {
					log.Warningf("couldn't parse cache db (err: %s), defaulting to %d", err, db)
				} else {
					db = int(parsedDB)
				}
			}
			authPlugin.cache = cache.NewSingleRedisStore(
				host,
				port,
				password,
				db,
				authExpiration,
				aclExpiration,
				authJitter,
				aclJitter,
				refreshExpiration,
			)
		}
	default:
		authPlugin.cache = cache.NewGoStore(
			authExpiration,
			aclExpiration,
			authJitter,
			aclJitter,
			refreshExpiration,
		)
	}

	if !authPlugin.cache.Connect(authPlugin.ctx, reset) {
		authPlugin.cache = nil
		authPlugin.useCache = false
		log.Infoln("couldn't start cache, defaulting to no cache")
	}
}

// parseSecondsOption reads authOpts[key] as a base-10 number of seconds.
// It returns def when the key is absent, unparsable, or negative, logging a
// warning that names the option as label in the latter two cases.
func parseSecondsOption(authOpts map[string]string, key, label string, def int64) int64 {
	raw, ok := authOpts[key]
	if !ok {
		return def
	}
	secs, err := strconv.ParseInt(raw, 10, 64)
	if err != nil {
		log.Warningf("couldn't parse %s (err: %s), defaulting to %d", label, err, def)
		return def
	}
	if secs < 0 {
		log.Warningf("%s can't be negative (got %d), defaulting to %d", label, secs, def)
		return def
	}
	return secs
}
// AuthUnpwdCheck is exported to mosquitto to authenticate a client by
// username and password. The check is retried while it returns an error, for
// at most retryCount+1 attempts, and always at least once.
//
// It returns AuthGranted or AuthRejected for a completed check, and AuthError
// (after logging the error) when the last attempt still failed.
//
//export AuthUnpwdCheck
func AuthUnpwdCheck(username, password, clientid *C.char) uint8 {
	// The C strings do not change between attempts, so convert them once.
	goUsername := C.GoString(username)
	goPassword := C.GoString(password)
	goClientID := C.GoString(clientid)

	// Guarantee at least one attempt: with a negative retryCount a
	// `try <= retryCount` loop would never run and silently reject the client.
	attempts := authPlugin.retryCount + 1
	if attempts < 1 {
		attempts = 1
	}

	var ok bool
	var err error
	for try := 0; try < attempts; try++ {
		ok, err = authUnpwdCheck(goUsername, goPassword, goClientID)
		if err == nil {
			break
		}
	}

	if err != nil {
		log.Error(err)
		return AuthError
	}
	if ok {
		return AuthGranted
	}
	return AuthRejected
}
// authUnpwdCheck answers a credential check from the cache when a record is
// present, and otherwise asks the backends. When caching is enabled and the
// backend answered without error, the result is stored in the cache as
// "true"/"false". A failure to store it is returned as the error (with a
// false result), so the caller treats the attempt as failed.
func authUnpwdCheck(username, password, clientid string) (bool, error) {
	if authPlugin.useCache {
		log.Debugf("checking auth cache for %s", username)
		if hit, granted := authPlugin.cache.CheckAuthRecord(authPlugin.ctx, username, password); hit {
			log.Debugf("found in cache: %s", username)
			return granted, nil
		}
	}

	authenticated, err := authPlugin.backends.AuthUnpwdCheck(username, password, clientid)
	// Nothing to cache when the backend failed or caching is off.
	if err != nil || !authPlugin.useCache {
		return authenticated, err
	}

	record := "false"
	if authenticated {
		record = "true"
	}
	log.Debugf("setting auth cache for %s", username)
	if cacheErr := authPlugin.cache.SetAuthRecord(authPlugin.ctx, username, password, record); cacheErr != nil {
		log.Errorf("set auth cache: %s", cacheErr)
		return false, cacheErr
	}
	return authenticated, nil
}
// AuthAclCheck is exported to mosquitto to decide whether a client may access
// a topic. acc is passed through unchanged to the cache and backends
// (presumably mosquitto's requested-access flag — confirm on the C side). The
// check is retried while it returns an error, for at most retryCount+1
// attempts, and always at least once.
//
// It returns AuthGranted or AuthRejected for a completed check, and AuthError
// (after logging the error) when the last attempt still failed.
//
//export AuthAclCheck
func AuthAclCheck(clientid, username, topic *C.char, acc C.int) uint8 {
	// The arguments do not change between attempts, so convert them once.
	goClientID := C.GoString(clientid)
	goUsername := C.GoString(username)
	goTopic := C.GoString(topic)
	access := int(acc)

	// Guarantee at least one attempt: with a negative retryCount a
	// `try <= retryCount` loop would never run and silently deny access.
	attempts := authPlugin.retryCount + 1
	if attempts < 1 {
		attempts = 1
	}

	var ok bool
	var err error
	for try := 0; try < attempts; try++ {
		ok, err = authAclCheck(goClientID, goUsername, goTopic, access)
		if err == nil {
			break
		}
	}

	if err != nil {
		log.Error(err)
		return AuthError
	}
	if ok {
		return AuthGranted
	}
	return AuthRejected
}
// authAclCheck answers an ACL check from the cache when a record is present,
// and otherwise asks the backends. When caching is enabled and the backend
// answered without error, the result is stored in the cache as
// "true"/"false". A failure to store it is returned as the error (with a
// false result), so the caller treats the attempt as failed.
func authAclCheck(clientid, username, topic string, acc int) (bool, error) {
	if authPlugin.useCache {
		log.Debugf("checking acl cache for %s", username)
		if hit, granted := authPlugin.cache.CheckACLRecord(authPlugin.ctx, username, topic, clientid, acc); hit {
			log.Debugf("found in cache: %s", username)
			return granted, nil
		}
	}

	allowed, err := authPlugin.backends.AuthAclCheck(clientid, username, topic, acc)
	if err == nil && authPlugin.useCache {
		record := "false"
		if allowed {
			record = "true"
		}
		log.Debugf("setting acl cache (granted = %s) for %s", record, username)
		if cacheErr := authPlugin.cache.SetACLRecord(authPlugin.ctx, username, topic, clientid, acc, record); cacheErr != nil {
			log.Errorf("set acl cache: %s", cacheErr)
			return false, cacheErr
		}
	}

	log.Debugf("Acl is %t for user %s", allowed, username)
	return allowed, err
}
// AuthPskKeyGet is exported to mosquitto and unconditionally returns true.
// NOTE(review): what true means is decided by the C wrapper (presumably that
// PSK handling may proceed) — confirm there.
//
//export AuthPskKeyGet
func AuthPskKeyGet() bool {
	const result = true
	return result
}
// AuthPluginCleanup is exported to mosquitto and called when the plugin is
// unloaded. It closes the cache (if one is active) and halts the backends.
// It tolerates being called without a completed AuthPluginInit, and a second
// call does not close the cache again.
//
//export AuthPluginCleanup
func AuthPluginCleanup() {
	log.Info("Cleaning up plugin")

	// If cache is set, close cache connection and drop the reference so it is
	// neither used nor closed again afterwards.
	if authPlugin.cache != nil {
		authPlugin.cache.Close()
		authPlugin.cache = nil
		authPlugin.useCache = false
	}

	// backends is nil if AuthPluginInit never set it up.
	if authPlugin.backends != nil {
		authPlugin.backends.Halt()
	}
}
// main is intentionally empty. This package exports functions to C via cgo,
// and building it as a C shared library (go build -buildmode=c-shared)
// requires a main package with a main function; it is never called at runtime.
func main() {
}