package repository

import (
	"encoding/json"
	"errors"
	"strings"

	"gitee.ltd/lxh/logger/log"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
	"gorm.io/gorm"

	"wireguard-dashboard/client"
	"wireguard-dashboard/http/param"
	"wireguard-dashboard/model/entity"
	"wireguard-dashboard/model/template_data"
	"wireguard-dashboard/model/vo"
	"wireguard-dashboard/utils"
)

// clientRepo is the data-access object for WireGuard clients (table
// t_wg_client). It embeds *gorm.DB so every GORM method is available on it.
type clientRepo struct {
	*gorm.DB
}

// Client
// @description: build a clientRepo on top of the shared client.DB connection
// @return clientRepo
func Client() clientRepo {
	return clientRepo{
		client.DB,
	}
}

// List
// @description: list clients page by page, newest first, optionally filtered
// by name (substring), email (exact), IP (substring of ip_allocation) and the
// enabled flag. The creating user's name is joined in from t_user.
// @receiver r
// @param p paging and filter parameters
// @return data the requested page; comma-separated columns are split into
// slices and the JSON keys column is decoded into Keys
// @return total number of rows matching the filters, ignoring paging
// @return err
func (r clientRepo) List(p param.ClientList) (data []vo.Client, total int64, err error) {
	sel := r.Table("t_wg_client as twc").
		Scopes(utils.Page(p.Current, p.Size)).
		Joins("LEFT JOIN t_user as tu ON twc.user_id = tu.id").
		Select("twc.id", "twc.created_at", "twc.updated_at", "twc.name", "twc.email", "twc.subnet_range",
			"twc.ip_allocation as ip_allocation_str", "twc.allowed_ips as allowed_ips_str",
			"twc.extra_allowed_ips as extra_allowed_ips_str", "twc.endpoint", "twc.use_server_dns",
			"twc.enable_after_creation", "twc.enabled", "twc.keys as keys_str", "tu.name as create_user",
			"twc.offline_monitoring")

	// Reassign the chain after every filter rather than relying on GORM
	// mutating sel in place.
	if p.Name != "" {
		sel = sel.Where("twc.name LIKE ?", "%"+p.Name+"%")
	}
	if p.Email != "" {
		sel = sel.Where("twc.email = ?", p.Email)
	}
	if p.Ip != "" {
		sel = sel.Where("twc.ip_allocation LIKE ?", "%"+p.Ip+"%")
	}
	if p.Enabled != nil {
		sel = sel.Where("twc.enabled = ?", p.Enabled)
	}

	// Fetch the page, then cancel offset/limit (-1) so Count covers every row
	// matching the same filters.
	err = sel.Order("twc.created_at DESC").Find(&data).Offset(-1).Limit(-1).Count(&total).Error
	if err != nil {
		return
	}

	// Index instead of ranging by value so each element is updated in place
	// without copying the struct.
	for i := range data {
		c := &data[i]
		if c.KeysStr != "" {
			// Best-effort: one undecodable row must not fail the whole page,
			// but the failure should not be silent either.
			if uerr := json.Unmarshal([]byte(c.KeysStr), &c.Keys); uerr != nil {
				log.Errorf("解析客户端密钥失败: %v", uerr.Error())
			}
		}
		if c.IpAllocationStr != "" {
			c.IpAllocation = strings.Split(c.IpAllocationStr, ",")
		}
		if c.AllowedIpsStr != "" {
			c.AllowedIps = strings.Split(c.AllowedIpsStr, ",")
		}
		if c.ExtraAllowedIpsStr != "" {
			c.ExtraAllowedIps = strings.Split(c.ExtraAllowedIpsStr, ",")
		}
	}
	return
}

// Save
// @description: create a new client, or update an existing one when p.Id is set.
// On update only the columns listed in the Select below are written
// (server_id and keys are left untouched). On create the requested IPs
// must not already belong to another client, and missing private/preshared
// keys are generated.
// @receiver r
// @param p client fields submitted by the caller
// @param adminId id of the user performing the save, stored as user_id
// @return data the saved client (nil on error); on update it holds the
// submitted values, not a fresh read from the database
// @return err
func (r clientRepo) Save(p param.SaveClient, adminId string) (data *entity.Client, err error) {
	ent := &entity.Client{
		Base: entity.Base{
			Id: p.Id,
		},
		ServerId:            p.ServerId,
		Name:                p.Name,
		Email:               p.Email,
		SubnetRange:         p.SubnetRange,
		IpAllocation:        strings.Join(p.IpAllocation, ","),
		AllowedIps:          strings.Join(p.AllowedIPS, ","),
		ExtraAllowedIps:     strings.Join(p.ExtraAllowedIPS, ","),
		Endpoint:            p.Endpoint,
		UseServerDns:        p.UseServerDNS,
		EnableAfterCreation: p.EnabledAfterCreation,
		UserId:              adminId,
		Enabled:             p.Enabled,
		OfflineMonitoring:   p.OfflineMonitoring,
	}

	// A non-empty id means an edit of an existing client.
	if p.Id != "" {
		// NOTE(review): Keys is filled in here but "keys" is not in the Select
		// list below, so submitted keys are never written on update. Confirm
		// whether that is intended before changing either side.
		keys, _ := json.Marshal(p.Keys)
		ent.Keys = string(keys)
		// Select restricts the UPDATE to these columns and also writes their
		// zero values (e.g. enabled=false), which a plain Updates would skip.
		if err = r.Model(&entity.Client{}).
			Where("id = ?", p.Id).
			Select("name", "email", "subnet_range", "ip_allocation", "allowed_ips", "extra_allowed_ips",
				"endpoint", "use_server_dns", "enable_after_creation", "user_id", "enabled", "offline_monitoring").
			Updates(ent).Error; err != nil {
			return nil, err
		}
		return ent, nil
	}

	// New client: refuse addresses already assigned to another client.
	if err = r.ensureIpsUnallocated(p.IpAllocation); err != nil {
		return nil, err
	}

	keys, err := newClientKeys(p.Keys.PrivateKey, p.Keys.PresharedKey)
	if err != nil {
		return nil, err
	}
	keysStr, err := json.Marshal(keys)
	if err != nil {
		return nil, err
	}
	ent.Keys = string(keysStr)

	if err = r.Model(&entity.Client{}).Create(ent).Error; err != nil {
		return nil, err
	}
	return ent, nil
}

// ensureIpsUnallocated
// @description: reject ips that are already assigned to an existing client.
// ip_allocation is stored as a comma-separated list, so every stored value is
// split and compared address by address; an SQL "IN" against the raw column
// would only match clients holding exactly one address.
// @receiver r
// @param ips addresses requested for a new client; blank entries are ignored
// @return error non-nil when an address is taken or the lookup fails
func (r clientRepo) ensureIpsUnallocated(ips []string) error {
	wanted := make(map[string]struct{}, len(ips))
	for _, ip := range ips {
		if ip = strings.TrimSpace(ip); ip != "" {
			wanted[ip] = struct{}{}
		}
	}
	if len(wanted) == 0 {
		return nil
	}

	var allocated []string
	if err := r.Model(&entity.Client{}).
		Where("ip_allocation IS NOT NULL").
		Pluck("ip_allocation", &allocated).Error; err != nil {
		log.Errorf("查询IP地址是否存在失败: %v", err.Error())
		return err
	}
	for _, row := range allocated {
		for _, ip := range strings.Split(row, ",") {
			if _, taken := wanted[strings.TrimSpace(ip)]; taken {
				return errors.New("该客户端的IP已经存在,请检查后再添加!")
			}
		}
	}
	return nil
}

// newClientKeys
// @description: build the key set for a new client. An empty private or
// preshared key is generated; a non-empty one is parsed and rejected when
// malformed. The public key is always derived from the private key.
// @param privateKeyStr encoded private key, or "" to generate one
// @param presharedKeyStr encoded preshared key, or "" to generate one
// @return template_data.Keys the resolved keys
// @return error a generic "解析密钥失败" error; the detailed cause is logged
func newClientKeys(privateKeyStr, presharedKeyStr string) (template_data.Keys, error) {
	var (
		privateKey, presharedKey wgtypes.Key
		err                      error
	)

	if privateKeyStr == "" {
		if privateKey, err = wgtypes.GeneratePrivateKey(); err != nil {
			log.Errorf("生成密钥对失败: %v", err.Error())
			return template_data.Keys{}, errors.New("解析密钥失败")
		}
	} else if privateKey, err = wgtypes.ParseKey(privateKeyStr); err != nil {
		log.Errorf("解析密钥对失败: %v", err.Error())
		return template_data.Keys{}, errors.New("解析密钥失败")
	}

	if presharedKeyStr == "" {
		if presharedKey, err = wgtypes.GenerateKey(); err != nil {
			log.Errorf("生成共享密钥失败: %v", err.Error())
			return template_data.Keys{}, errors.New("解析密钥失败")
		}
	} else if presharedKey, err = wgtypes.ParseKey(presharedKeyStr); err != nil {
		log.Errorf("解析共享密钥失败: %v", err.Error())
		return template_data.Keys{}, errors.New("解析密钥失败")
	}

	return template_data.Keys{
		PrivateKey:   privateKey.String(),
		PublicKey:    privateKey.PublicKey().String(),
		PresharedKey: presharedKey.String(),
	}, nil
}

// Delete
// @description: delete the client with the given id
// @receiver r
// @param id client id
// @return err
func (r clientRepo) Delete(id string) (err error) {
	return r.Model(&entity.Client{}).Where("id = ?", id).Delete(&entity.Client{}).Error
}

// GetById
// @description: load one client by id, preloading its Server association
// @receiver r
// @param id client id
// @return data
// @return err GORM's error when no row matches (First)
func (r clientRepo) GetById(id string) (data entity.Client, err error) {
	err = r.Model(&entity.Client{}).Where("id = ?", id).Preload("Server").First(&data).Error
	return
}

// GetByPublicKey
// @description: load one client by the publicKey stored inside its JSON keys
// column, preloading its Server association
// @receiver r
// @param publicKey the client's public key
// @return data
// @return err
func (r clientRepo) GetByPublicKey(publicKey string) (data entity.Client, err error) {
	// Bind publicKey as a query parameter; formatting it into the SQL string
	// allowed SQL injection through the key value.
	// NOTE(review): json_extract and the unquoted "keys" column assume a
	// SQLite-style dialect; verify before running against other databases.
	err = r.Model(&entity.Client{}).
		Where("json_extract(keys, '$.publicKey') = ?", publicKey).
		Preload("Server").
		First(&data).Error
	return
}

// Disabled
// @description: disable the client with the given id (sets enabled to 0)
// @receiver r
// @param id client id
// @return err
func (r clientRepo) Disabled(id string) (err error) {
	return r.Model(&entity.Client{}).Where("id = ?", id).Update("enabled", 0).Error
}