mirror of https://github.com/casbin/casnode.git
Add getDefaultModelProvider()
This commit is contained in:
parent
f478db4af8
commit
43e8dabc2c
|
@ -137,7 +137,7 @@ func (c *ApiController) GetMessageAnswer() {
|
|||
return
|
||||
}
|
||||
|
||||
if provider.Category != "AI" || provider.ClientSecret == "" {
|
||||
if provider.Category != "Model" || provider.ClientSecret == "" {
|
||||
c.ResponseErrorStream(fmt.Sprintf("The provider: %s is invalid", providerId))
|
||||
return
|
||||
}
|
||||
|
|
|
@ -96,6 +96,17 @@ func UpdateChat(id string, chat *Chat) (bool, error) {
|
|||
}
|
||||
|
||||
func AddChat(chat *Chat) (bool, error) {
|
||||
if chat.Type == "AI" && chat.User2 == "" {
|
||||
provider, err := getDefaultModelProvider()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if provider != nil {
|
||||
chat.User2 = provider.Name
|
||||
}
|
||||
}
|
||||
|
||||
affected, err := adapter.engine.Insert(chat)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
|
|
@ -100,6 +100,20 @@ func GetProvider(id string) (*Provider, error) {
|
|||
return getProvider(owner, name)
|
||||
}
|
||||
|
||||
func getDefaultModelProvider() (*Provider, error) {
|
||||
provider := Provider{Owner: "admin", Category: "Model"}
|
||||
existed, err := adapter.engine.Get(&provider)
|
||||
if err != nil {
|
||||
return &provider, err
|
||||
}
|
||||
|
||||
if !existed {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &provider, nil
|
||||
}
|
||||
|
||||
func UpdateProvider(id string, provider *Provider) (bool, error) {
|
||||
owner, name := util.GetOwnerAndNameFromId(id)
|
||||
_, err := getProvider(owner, name)
|
||||
|
|
Loading…
Reference in New Issue