Compare commits
6 Commits
bc444dd7d4
...
d019677fb0
Author | SHA1 | Date |
---|---|---|
Justin Dalrymple | d019677fb0 | |
Ignacio Gómez | afd1bd78b3 | |
Jesper Falk | 05bbc60380 | |
Jesper Falk | 6480eb86e2 | |
Justin Dalrymple | 84ac164557 | |
Justin Dalrymple | e90423f7fd |
|
@ -3,7 +3,8 @@ package backends
|
|||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"context"
|
||||
|
||||
"github.com/iegomez/mosquitto-go-auth/hashing"
|
||||
"github.com/pkg/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
@ -13,6 +14,7 @@ type localJWTChecker struct {
|
|||
db string
|
||||
postgres Postgres
|
||||
mysql Mysql
|
||||
mongo Mongo
|
||||
userQuery string
|
||||
hasher hashing.HashComparer
|
||||
options tokenOptions
|
||||
|
@ -21,10 +23,11 @@ type localJWTChecker struct {
|
|||
const (
|
||||
mysqlDB = "mysql"
|
||||
postgresDB = "postgres"
|
||||
mongoDB = "mongo"
|
||||
)
|
||||
|
||||
// NewLocalJWTChecker initializes a checker with a local DB.
|
||||
func NewLocalJWTChecker(authOpts map[string]string, logLevel log.Level, hasher hashing.HashComparer, options tokenOptions) (jwtChecker, error) {
|
||||
func NewLocalJWTChecker(authOpts map[string]string, logLevel log.Level, hasher hashing.HashComparer, options tokenOptions) (JWTChecker, error) {
|
||||
checker := &localJWTChecker{
|
||||
hasher: hasher,
|
||||
db: postgresDB,
|
||||
|
@ -35,7 +38,7 @@ func NewLocalJWTChecker(authOpts map[string]string, logLevel log.Level, hasher h
|
|||
localOk := true
|
||||
|
||||
if options.secret == "" {
|
||||
return nil, errors.New("JWT backend error: missing jwt secret")
|
||||
return nil, errors.New("JWT backend error: missing JWT secret")
|
||||
}
|
||||
|
||||
if db, ok := authOpts["jwt_db"]; ok {
|
||||
|
@ -59,21 +62,28 @@ func NewLocalJWTChecker(authOpts map[string]string, logLevel log.Level, hasher h
|
|||
if checker.db == mysqlDB {
|
||||
mysql, err := NewMysql(dbAuthOpts, logLevel, hasher)
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("JWT backend error: couldn't create mysql connector for local jwt: %s", err)
|
||||
return nil, errors.Errorf("JWT backend error: couldn't create mysql connector for local JWT: %s", err)
|
||||
}
|
||||
|
||||
checker.mysql = mysql
|
||||
} else if checker.db == mongoDB {
|
||||
mongodb, err := NewMongo(dbAuthOpts, logLevel, hasher)
|
||||
|
||||
return checker, nil
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("JWT backend error: couldn't create mysql connector for local JWT: %s", err)
|
||||
}
|
||||
|
||||
checker.mongo = mongodb
|
||||
} else {
|
||||
postgres, err := NewPostgres(dbAuthOpts, logLevel, hasher)
|
||||
|
||||
checker.postgres = postgres
|
||||
}
|
||||
|
||||
postgres, err := NewPostgres(dbAuthOpts, logLevel, hasher)
|
||||
|
||||
if err != nil {
|
||||
return nil, errors.Errorf("JWT backend error: couldn't create postgres connector for local jwt: %s", err)
|
||||
return nil, errors.Errorf("JWT backend error: couldn't create postgres connector for local JWT: %s", err)
|
||||
}
|
||||
|
||||
checker.postgres = postgres
|
||||
|
||||
return checker, nil
|
||||
}
|
||||
|
||||
|
@ -81,7 +91,7 @@ func (o *localJWTChecker) GetUser(token string) (bool, error) {
|
|||
username, err := getUsernameForToken(o.options, token, o.options.skipUserExpiration)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("jwt local get user error: %s", err)
|
||||
log.Printf("JWT local get user error: %s", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
|
@ -92,43 +102,47 @@ func (o *localJWTChecker) GetSuperuser(token string) (bool, error) {
|
|||
username, err := getUsernameForToken(o.options, token, o.options.skipUserExpiration)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("jwt local get superuser error: %s", err)
|
||||
log.Printf("JWT local get superuser error: %s", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
if o.db == mysqlDB {
|
||||
return o.mysql.GetSuperuser(username)
|
||||
} else if o.db == mongoDB {
|
||||
return o.mongo.GetSuperuser(username)
|
||||
} else {
|
||||
return o.postgres.GetSuperuser(username)
|
||||
}
|
||||
|
||||
return o.postgres.GetSuperuser(username)
|
||||
}
|
||||
|
||||
func (o *localJWTChecker) CheckAcl(token, topic, clientid string, acc int32) (bool, error) {
|
||||
username, err := getUsernameForToken(o.options, token, o.options.skipACLExpiration)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("jwt local check acl error: %s", err)
|
||||
log.Printf("JWT local check acl error: %s", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
if o.db == mysqlDB {
|
||||
return o.mysql.CheckAcl(username, topic, clientid, acc)
|
||||
} else if o.db == mongoDB {
|
||||
return o.mongo.CheckAcl(username)
|
||||
} else {
|
||||
return o.postgres.CheckAcl(username)
|
||||
}
|
||||
|
||||
return o.postgres.CheckAcl(username, topic, clientid, acc)
|
||||
}
|
||||
|
||||
func (o *localJWTChecker) Halt() {
|
||||
if o.postgres != (Postgres{}) && o.postgres.DB != nil {
|
||||
err := o.postgres.DB.Close()
|
||||
if err != nil {
|
||||
log.Errorf("JWT cleanup error: %s", err)
|
||||
}
|
||||
} else if o.mysql != (Mysql{}) && o.mysql.DB != nil {
|
||||
err := o.mysql.DB.Close()
|
||||
if err != nil {
|
||||
log.Errorf("JWT cleanup error: %s", err)
|
||||
}
|
||||
} else if o.mongo != (Mongo{}) && o.mongo.Conn != nil {
|
||||
err := o.mongo.Conn.Disconnect(context.TODO())
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("JWT cleanup error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -137,25 +151,36 @@ func (o *localJWTChecker) getLocalUser(username string) (bool, error) {
|
|||
return false, nil
|
||||
}
|
||||
|
||||
var count sql.NullInt64
|
||||
var err error
|
||||
var sqlCount sql.NullInt64
|
||||
var count Int64
|
||||
var valid boolean
|
||||
|
||||
if o.db == mysqlDB {
|
||||
err = o.mysql.DB.Get(&count, o.userQuery, username)
|
||||
valid = sqlCount.Valid
|
||||
count = sqlCount.Int64
|
||||
} else if o.db == mongoDB {
|
||||
var uc := o.mongo.Conn.Database(o.mongo.DBName).Collection(o.mongo.UsersCollection)
|
||||
|
||||
count, err := uc.CountDocuments(context.TODO(), bson.M{"username": username})
|
||||
} else {
|
||||
err = o.postgres.DB.Get(&count, o.userQuery, username)
|
||||
}
|
||||
valid = sqlCount.Valid
|
||||
count = sqlCount.Int64
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Debugf("local JWT get user error: %s", err)
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !count.Valid {
|
||||
if !valid {
|
||||
log.Debugf("local JWT get user error: user %s not found", username)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if count.Int64 > 0 {
|
||||
if count > 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
@ -165,9 +190,12 @@ func (o *localJWTChecker) getLocalUser(username string) (bool, error) {
|
|||
func extractOpts(authOpts map[string]string, db string) map[string]string {
|
||||
dbAuthOpts := make(map[string]string)
|
||||
|
||||
dbPrefix := "pg"
|
||||
if db == mysqlDB {
|
||||
dbPrefix = mysqlDB
|
||||
} else if db == mongoDB {
|
||||
dbPrefix = mongoDB
|
||||
} else {
|
||||
dbPrefix := "pg"
|
||||
}
|
||||
|
||||
prefix := "jwt_" + dbPrefix
|
||||
|
|
|
@ -126,22 +126,29 @@ func TestArgon2ID(t *testing.T) {
|
|||
|
||||
func TestPBKDF2(t *testing.T) {
|
||||
password := "test-password"
|
||||
b64Hasher := NewPBKDF2Hasher(defaultPBKDF2SaltSize, defaultPBKDF2Iterations, defaultPBKDF2Algorithm, Base64, defaultPBKDF2KeyLen)
|
||||
utf8Hasher := NewPBKDF2Hasher(defaultPBKDF2SaltSize, defaultPBKDF2Iterations, defaultPBKDF2Algorithm, UTF8, defaultPBKDF2KeyLen)
|
||||
|
||||
// Test base64.
|
||||
hasher := NewPBKDF2Hasher(defaultPBKDF2SaltSize, defaultPBKDF2Iterations, defaultPBKDF2Algorithm, Base64, defaultPBKDF2KeyLen)
|
||||
t.Run("OlderFormat", func(t *testing.T) {
|
||||
t.Run("Base64", func(t *testing.T) {
|
||||
passwordHash, err := b64Hasher.Hash(password)
|
||||
|
||||
passwordHash, err := hasher.Hash(password)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, b64Hasher.Compare(password, passwordHash))
|
||||
assert.False(t, b64Hasher.Compare("other", passwordHash))
|
||||
})
|
||||
t.Run("UTF8", func(t *testing.T) {
|
||||
passwordHash, err := utf8Hasher.Hash(password)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, hasher.Compare(password, passwordHash))
|
||||
assert.False(t, hasher.Compare("other", passwordHash))
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, utf8Hasher.Compare(password, passwordHash))
|
||||
assert.False(t, utf8Hasher.Compare("other", passwordHash))
|
||||
})
|
||||
})
|
||||
|
||||
// Test UTF8.
|
||||
hasher = NewPBKDF2Hasher(defaultPBKDF2SaltSize, defaultPBKDF2Iterations, defaultPBKDF2Algorithm, UTF8, defaultPBKDF2KeyLen)
|
||||
|
||||
passwordHash, err = hasher.Hash(password)
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, hasher.Compare(password, passwordHash))
|
||||
assert.False(t, hasher.Compare("other", passwordHash))
|
||||
t.Run("PHC-SF-Spec", func(t *testing.T) {
|
||||
passwordHash := "$pbkdf2-sha512$i=10000,l=32$/DsNR8DBuoF/MxzLY+QVaw$YNfYNfT+6yT2blLrXKKR8Ll+aesgHYqSOtFTBsyscRM"
|
||||
assert.True(t, b64Hasher.Compare(password, passwordHash))
|
||||
assert.False(t, b64Hasher.Compare("other", passwordHash))
|
||||
})
|
||||
}
|
||||
|
|
|
@ -23,13 +23,13 @@ type pbkdf2Hasher struct {
|
|||
keyLen int
|
||||
}
|
||||
|
||||
func NewPBKDF2Hasher(saltSize int, iterations int, algorithm string, saltEncoding string, keylen int) HashComparer {
|
||||
func NewPBKDF2Hasher(saltSize int, iterations int, algorithm string, saltEncoding string, keyLen int) HashComparer {
|
||||
return pbkdf2Hasher{
|
||||
saltSize: saltSize,
|
||||
iterations: iterations,
|
||||
algorithm: algorithm,
|
||||
saltEncoding: preferredEncoding(saltEncoding),
|
||||
keyLen: keylen,
|
||||
keyLen: keyLen,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -37,20 +37,18 @@ func NewPBKDF2Hasher(saltSize int, iterations int, algorithm string, saltEncodin
|
|||
* PBKDF2 methods are adapted from github.com/brocaar/chirpstack-application-server, some comments included.
|
||||
*/
|
||||
|
||||
// Hash function reference may be found at https://github.com/brocaar/chirpstack-application-server/blob/master/internal/storage/user.go#L421.
|
||||
|
||||
// Generate the hash of a password for storage in the database.
|
||||
// NOTE: We store the details of the hashing algorithm with the hash itself,
|
||||
// making it easy to recreate the hash for password checking, even if we change
|
||||
// the default criteria here.
|
||||
// Hash function generates a hash of the supplied password. The hash
|
||||
// can then be stored directly in the database. The return hash will
|
||||
// contain options according to the PHC String format found at
|
||||
// https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md
|
||||
func (h pbkdf2Hasher) Hash(password string) (string, error) {
|
||||
// Generate a random salt value with the given salt size.
|
||||
salt := make([]byte, h.saltSize)
|
||||
_, err := rand.Read(salt)
|
||||
|
||||
// We need to ensure that salt doesn contain $, which is 36 in decimal.
|
||||
// So we check if there'sbyte that represents $ and change it with a random number in the range 0-35
|
||||
//// This is far from ideal, but should be good enough with a reasonable salt size.
|
||||
// We need to ensure that salt doesn't contain $, which is 36 in decimal.
|
||||
// So we check if there's byte that represents $ and change it with a random number in the range 0-35
|
||||
// // This is far from ideal, but should be good enough with a reasonable salt size.
|
||||
for i := 0; i < len(salt); i++ {
|
||||
if salt[i] == 36 {
|
||||
n, err := rand.Int(rand.Reader, big.NewInt(35))
|
||||
|
@ -69,52 +67,125 @@ func (h pbkdf2Hasher) Hash(password string) (string, error) {
|
|||
return h.hashWithSalt(password, salt, h.iterations, h.algorithm, h.keyLen), nil
|
||||
}
|
||||
|
||||
// HashCompare verifies that passed password hashes to the same value as the
|
||||
// Compare verifies that passed password hashes to the same value as the
|
||||
// passed passwordHash.
|
||||
// Reference: https://github.com/brocaar/chirpstack-application-server/blob/master/internal/storage/user.go#L458.
|
||||
// Parsing reference: https://github.com/P-H-C/phc-string-format/blob/master/phc-sf-spec.md
|
||||
func (h pbkdf2Hasher) Compare(password string, passwordHash string) bool {
|
||||
hashSplit := strings.Split(passwordHash, "$")
|
||||
hashSplit := h.getFields(passwordHash)
|
||||
|
||||
if len(hashSplit) != 5 {
|
||||
log.Errorf("invalid PBKDF2 hash supplied, expected length 5, got: %d", len(hashSplit))
|
||||
return false
|
||||
}
|
||||
|
||||
algorithm := hashSplit[1]
|
||||
|
||||
iterations, err := strconv.Atoi(hashSplit[2])
|
||||
if err != nil {
|
||||
log.Errorf("iterations error: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
var salt []byte
|
||||
switch h.saltEncoding {
|
||||
case UTF8:
|
||||
salt = []byte(hashSplit[3])
|
||||
default:
|
||||
salt, err = base64.StdEncoding.DecodeString(hashSplit[3])
|
||||
var (
|
||||
err error
|
||||
algorithm string
|
||||
paramString string
|
||||
hashedPassword []byte
|
||||
salt []byte
|
||||
iterations int
|
||||
keyLen int
|
||||
)
|
||||
if hashSplit[0] == "PBKDF2" {
|
||||
algorithm = hashSplit[1]
|
||||
iterations, err = strconv.Atoi(hashSplit[2])
|
||||
if err != nil {
|
||||
log.Errorf("base64 salt error: %s", err)
|
||||
log.Errorf("iterations error: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
switch h.saltEncoding {
|
||||
case UTF8:
|
||||
salt = []byte(hashSplit[3])
|
||||
default:
|
||||
var err error
|
||||
salt, err = base64.StdEncoding.DecodeString(hashSplit[3])
|
||||
if err != nil {
|
||||
log.Errorf("base64 salt error: %s", err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
hashedPassword, err = base64.StdEncoding.DecodeString(hashSplit[4])
|
||||
if err != nil {
|
||||
log.Errorf("base64 hash decoding error: %s", err)
|
||||
return false
|
||||
}
|
||||
keyLen = len(hashedPassword)
|
||||
|
||||
} else if hashSplit[0] == "pbkdf2-sha512" {
|
||||
algorithm = "sha512"
|
||||
paramString = hashSplit[1]
|
||||
|
||||
opts := strings.Split(paramString, ",")
|
||||
for _, opt := range opts {
|
||||
parts := strings.Split(opt, "=")
|
||||
for i := 0; i < len(parts); i += 2 {
|
||||
key := parts[i]
|
||||
val := parts[i+1]
|
||||
switch key {
|
||||
case "i":
|
||||
iterations, _ = strconv.Atoi(val)
|
||||
case "l":
|
||||
keyLen, _ = strconv.Atoi(val)
|
||||
default:
|
||||
log.Errorf("unknown options key (\"%s\")", key)
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch h.saltEncoding {
|
||||
case UTF8:
|
||||
salt = []byte(hashSplit[2])
|
||||
default:
|
||||
var err error
|
||||
salt, err = base64.StdEncoding.WithPadding(base64.NoPadding).DecodeString(hashSplit[2])
|
||||
if err != nil {
|
||||
log.Errorf("base64 salt error: %s", err)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
hashedPassword, err = base64.StdEncoding.WithPadding(base64.NoPadding).DecodeString(hashSplit[3])
|
||||
} else {
|
||||
log.Errorf("invalid PBKDF2 hash supplied, unrecognized format \"%s\"", hashSplit[0])
|
||||
return false
|
||||
}
|
||||
|
||||
newHash := h.hashWithSalt(password, salt, iterations, algorithm, keyLen)
|
||||
hashSplit = h.getFields(newHash)
|
||||
newHashedPassword, err := base64.StdEncoding.DecodeString(hashSplit[4])
|
||||
if err != nil {
|
||||
log.Errorf("base64 salt error: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
return h.compareBytes(hashedPassword, newHashedPassword)
|
||||
}
|
||||
|
||||
func (h pbkdf2Hasher) compareBytes(a, b []byte) bool {
|
||||
for i, x := range a {
|
||||
if b[i] != x {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
hashedPassword, err := base64.StdEncoding.DecodeString(hashSplit[4])
|
||||
if err != nil {
|
||||
log.Errorf("base64 hash decoding error: %s", err)
|
||||
return false
|
||||
}
|
||||
|
||||
keylen := len(hashedPassword)
|
||||
|
||||
return passwordHash == h.hashWithSalt(password, salt, iterations, algorithm, keylen)
|
||||
func (h pbkdf2Hasher) getFields(passwordHash string) []string {
|
||||
hashSplit := strings.FieldsFunc(passwordHash, func(r rune) bool {
|
||||
switch r {
|
||||
case '$':
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
})
|
||||
return hashSplit
|
||||
}
|
||||
|
||||
// Reference: https://github.com/brocaar/chirpstack-application-server/blob/master/internal/storage/user.go#L432.
|
||||
func (h pbkdf2Hasher) hashWithSalt(password string, salt []byte, iterations int, algorithm string, keylen int) string {
|
||||
// Generate the hashed password. This should be a little painful, adjust ITERATIONS
|
||||
// if it needs performance tweeking. Greatly depends on the hardware.
|
||||
// if it needs performance tweaking. Greatly depends on the hardware.
|
||||
// NOTE: We store these details with the returned hashed, so changes will not
|
||||
// affect our ability to do password compares.
|
||||
shaHash := sha512.New
|
||||
|
|
Loading…
Reference in New Issue