2024-03-11 14:53:28 +08:00
package repository
import (
"encoding/json"
2024-03-14 15:23:16 +08:00
"errors"
"gitee.ltd/lxh/logger/log"
2024-03-13 17:05:02 +08:00
"github.com/spf13/cast"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
2024-03-11 14:53:28 +08:00
"gorm.io/gorm"
2024-03-14 15:23:16 +08:00
"strings"
2024-03-11 14:53:28 +08:00
"wireguard-dashboard/client"
"wireguard-dashboard/http/param"
2024-03-13 17:05:02 +08:00
"wireguard-dashboard/model/entity"
"wireguard-dashboard/model/template_data"
2024-03-11 14:53:28 +08:00
"wireguard-dashboard/model/vo"
"wireguard-dashboard/utils"
)
// clientRepo is the data-access object for WireGuard client records.
// It embeds *gorm.DB so GORM query-builder methods can be called on it directly.
type clientRepo struct {
	*gorm.DB
}

// Client returns a clientRepo bound to the shared database handle client.DB.
func Client() clientRepo {
	repo := clientRepo{DB: client.DB}
	return repo
}
// List
// @description: 列表
// @receiver r
// @param p
// @return data
// @return total
// @return err
func ( r clientRepo ) List ( p param . ClientList ) ( data [ ] vo . Client , total int64 , err error ) {
err = r . Table ( "t_wg_client as twc" ) . Scopes ( utils . Page ( p . Current , p . Size ) ) . Joins ( "LEFT JOIN t_user as tu ON twc.user_id = tu.id" ) .
2024-03-14 15:23:16 +08:00
Select ( "twc.id" , "twc.created_at" , "twc.updated_at" , "twc.name" , "twc.email" , "twc.subnet_range" , "twc.ip_allocation as ip_allocation_str" , "twc.allowed_ips as allowed_ips_str" ,
"twc.extra_allowed_ips as extra_allowed_ips_str" , "twc.endpoint" , "twc.use_server_dns" , "twc.enable_after_creation" , "twc.enabled" , "twc.keys as keys_str" , "tu.name as create_user" ) .
Order ( "twc.created_at DESC" ) . Find ( & data ) . Offset ( - 1 ) . Limit ( - 1 ) . Count ( & total ) . Error
2024-03-11 14:53:28 +08:00
if err != nil {
return
}
for i , v := range data {
if v . KeysStr != "" {
_ = json . Unmarshal ( [ ] byte ( v . KeysStr ) , & data [ i ] . Keys )
}
2024-03-14 15:23:16 +08:00
if v . IpAllocationStr != "" {
data [ i ] . IpAllocation = strings . Split ( v . IpAllocationStr , "," )
}
if v . AllowedIpsStr != "" {
data [ i ] . AllowedIps = strings . Split ( v . AllowedIpsStr , "," )
}
if v . ExtraAllowedIpsStr != "" {
data [ i ] . ExtraAllowedIps = strings . Split ( v . ExtraAllowedIpsStr , "," )
}
2024-03-11 14:53:28 +08:00
}
return
}
2024-03-11 17:26:41 +08:00
// Save
// @description: 新增/编辑客户端
// @receiver r
// @param p
2024-03-13 17:05:02 +08:00
// @param adminId
2024-03-11 17:26:41 +08:00
// @return err
2024-03-13 17:05:02 +08:00
func ( r clientRepo ) Save ( p param . SaveClient , adminId string ) ( client * entity . Client , err error ) {
ent := & entity . Client {
Base : entity . Base {
Id : p . Id ,
} ,
ServerId : p . ServerId ,
Name : p . Name ,
Email : p . Email ,
SubnetRange : p . SubnetRange ,
2024-03-14 15:23:16 +08:00
IpAllocation : strings . Join ( p . IpAllocation , "," ) ,
AllowedIps : strings . Join ( p . AllowedIPS , "," ) ,
ExtraAllowedIps : strings . Join ( p . ExtraAllowedIPS , "," ) ,
2024-03-13 17:05:02 +08:00
Endpoint : p . Endpoint ,
UseServerDns : p . UseServerDNS ,
EnableAfterCreation : p . EnabledAfterCreation ,
UserId : adminId ,
Enabled : cast . ToBool ( p . Enabled ) ,
}
// id不为空, 更新信息
if p . Id != "" {
keys , _ := json . Marshal ( p . Keys )
ent . Keys = string ( keys )
if err = r . Model ( & entity . Client { } ) . Where ( "id = ?" , p . Id ) . Updates ( ent ) . Error ; err != nil {
return
}
return
}
2024-03-14 15:23:16 +08:00
// 查询新增的ip地址是否已经存在了
var count int64
if err = r . Model ( & entity . Client { } ) . Where ( "ip_allocation in (?)" , p . IpAllocation ) . Count ( & count ) . Error ; err != nil {
log . Errorf ( "查询IP地址是否存在失败: %v" , err . Error ( ) )
return
}
if count > 0 {
return nil , errors . New ( "该客户端的IP已经存在, 请检查后再添加! " )
}
2024-03-13 17:05:02 +08:00
// 为空,新增
privateKey , err := wgtypes . GeneratePrivateKey ( )
if err != nil {
return
}
publicKey := privateKey . PublicKey ( ) . String ( )
presharedKey , err := wgtypes . GenerateKey ( )
if err != nil {
return
}
keys := template_data . Keys {
2024-03-13 17:30:39 +08:00
PrivateKey : privateKey . String ( ) ,
2024-03-13 17:05:02 +08:00
PublicKey : publicKey ,
PresharedKey : presharedKey . String ( ) ,
}
keysStr , _ := json . Marshal ( keys )
ent = & entity . Client {
ServerId : p . ServerId ,
Name : p . Name ,
Email : p . Email ,
SubnetRange : p . SubnetRange ,
2024-03-14 15:23:16 +08:00
IpAllocation : strings . Join ( p . IpAllocation , "," ) ,
AllowedIps : strings . Join ( p . AllowedIPS , "," ) ,
ExtraAllowedIps : strings . Join ( p . ExtraAllowedIPS , "," ) ,
2024-03-13 17:05:02 +08:00
Endpoint : p . Endpoint ,
UseServerDns : p . UseServerDNS ,
EnableAfterCreation : p . EnabledAfterCreation ,
Keys : string ( keysStr ) ,
UserId : adminId ,
2024-03-13 17:30:39 +08:00
Enabled : true ,
2024-03-13 17:05:02 +08:00
}
err = r . Model ( & entity . Client { } ) . Create ( ent ) . Error
return
2024-03-11 17:26:41 +08:00
}
2024-03-14 15:23:16 +08:00
// Delete
// @description: deletes the client whose id equals the given id.
// @receiver r
// @param id id of the client to delete
// @return err database error, if any
func (r clientRepo) Delete(id string) (err error) {
	// r.DB is named explicitly because this method shadows the embedded
	// (*gorm.DB).Delete; the condition is passed inline.
	result := r.DB.Delete(&entity.Client{}, "id = ?", id)
	return result.Error
}
// GetById
// @description: 根据id获取客户端详情
// @receiver r
// @param id
// @return data
// @return err
func ( r clientRepo ) GetById ( id string ) ( data entity . Client , err error ) {
err = r . Model ( & entity . Client { } ) . Where ( "id = ?" , id ) . Preload ( "Server" ) . First ( & data ) . Error
2024-03-15 15:17:40 +08:00
return
}
// GetByPublicKey
// @description: 根据公钥获取客户端信息
// @receiver r
// @param publicKey
// @return data
// @return err
func ( r clientRepo ) GetByPublicKey ( publicKey string ) ( data entity . Client , err error ) {
err = r . Model ( & entity . Client { } ) . Where ( "keys->$.publicKey = ?" , publicKey ) . Preload ( "Server" ) . First ( & data ) . Error
2024-03-14 15:23:16 +08:00
return
}