package repository

import (
	"encoding/json"
	"errors"
	"gitee.ltd/lxh/logger/log"
	"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
	"gorm.io/gorm"
	"strings"
	"wireguard-dashboard/client"
	"wireguard-dashboard/http/param"
	"wireguard-dashboard/model/entity"
	"wireguard-dashboard/model/template_data"
	"wireguard-dashboard/model/vo"
	"wireguard-dashboard/utils"
)

// clientRepo is the repository for WireGuard client records.
//
// It embeds *gorm.DB so gorm's query-builder methods (Model, Table, Where, ...)
// can be called directly on the repository value. Note that clientRepo declares
// its own Save and Delete methods, which shadow the embedded gorm methods of the
// same name; use r.DB.Save / r.DB.Delete when the gorm versions are needed.
type clientRepo struct {
	*gorm.DB // shared database handle; its methods are promoted onto clientRepo
}

// Client returns a clientRepo bound to the package-level client.DB handle.
func Client() clientRepo {
	repo := clientRepo{DB: client.DB}
	return repo
}

// List returns one page of clients matching the filters in p, together with
// the total number of matching rows across all pages.
// @description: paged client list with optional filters
// @receiver r
// @param p: paging (Current, Size) and filters. Name is a substring match,
// Email an exact match, Ip a substring match on the comma-joined ip_allocation
// column, and Enabled an exact match applied only when non-nil.
// @return data: the requested page, newest first (created_at DESC)
// @return total: number of matching rows, ignoring pagination
// @return err: query error; data and total are zero when it is non-nil
func (r clientRepo) List(p param.ClientList) (data []vo.Client, total int64, err error) {
	sel := r.Table("t_wg_client as twc").
		Scopes(utils.Page(p.Current, p.Size)).
		Joins("LEFT JOIN t_user as tu ON twc.user_id = tu.id").
		Select("twc.id", "twc.created_at", "twc.updated_at", "twc.name", "twc.email", "twc.subnet_range", "twc.ip_allocation as ip_allocation_str", "twc.allowed_ips as allowed_ips_str",
			"twc.extra_allowed_ips as extra_allowed_ips_str", "twc.endpoint", "twc.use_server_dns", "twc.enable_after_creation", "twc.enabled", "twc.keys as keys_str", "tu.name as create_user")

	// Always keep the *gorm.DB returned by a chain method: gorm's chaining
	// contract only guarantees that the returned value carries the condition.
	if p.Name != "" {
		sel = sel.Where("twc.name LIKE ?", "%"+p.Name+"%")
	}
	if p.Email != "" {
		sel = sel.Where("twc.email = ?", p.Email)
	}
	if p.Ip != "" {
		sel = sel.Where("twc.ip_allocation LIKE ?", "%"+p.Ip+"%")
	}
	if p.Enabled != nil {
		sel = sel.Where("twc.enabled = ?", p.Enabled)
	}

	// Find fetches the current page. Offset(-1)/Limit(-1) then drop the
	// pagination so that Count reports the total over all matching rows.
	err = sel.Order("twc.created_at DESC").Find(&data).Offset(-1).Limit(-1).Count(&total).Error
	if err != nil {
		return nil, 0, err
	}

	// The comma-joined and JSON columns were selected under *_str aliases.
	// Decode them into their structured fields.
	for i := range data {
		c := &data[i]
		if c.KeysStr != "" {
			if jsonErr := json.Unmarshal([]byte(c.KeysStr), &c.Keys); jsonErr != nil {
				// Best effort: one corrupt row should not fail the whole
				// listing, but it must not go unnoticed either.
				log.Errorf("解析客户端密钥失败: %v", jsonErr)
			}
		}
		if c.IpAllocationStr != "" {
			c.IpAllocation = strings.Split(c.IpAllocationStr, ",")
		}
		if c.AllowedIpsStr != "" {
			c.AllowedIps = strings.Split(c.AllowedIpsStr, ",")
		}
		if c.ExtraAllowedIpsStr != "" {
			c.ExtraAllowedIps = strings.Split(c.ExtraAllowedIpsStr, ",")
		}
	}

	return data, total, nil
}

// Save creates a new client or updates an existing one, and returns the saved
// entity.
// @description: create / edit a client
// @receiver r
// @param p: client fields. A non-empty p.Id selects the update path.
// @param adminId: stored as the client's user_id
// @return *entity.Client: the entity that was written (nil on error)
// @return error: a database or key-generation error, or a conflict error when
// one of the requested IP addresses is already allocated (create path only)
func (r clientRepo) Save(p param.SaveClient, adminId string) (*entity.Client, error) {
	// On the create path p.Id is empty, so Base is zero here as well.
	ent := &entity.Client{
		Base: entity.Base{
			Id: p.Id,
		},
		ServerId:            p.ServerId,
		Name:                p.Name,
		Email:               p.Email,
		SubnetRange:         p.SubnetRange,
		IpAllocation:        strings.Join(p.IpAllocation, ","),
		AllowedIps:          strings.Join(p.AllowedIPS, ","),
		ExtraAllowedIps:     strings.Join(p.ExtraAllowedIPS, ","),
		Endpoint:            p.Endpoint,
		UseServerDns:        p.UseServerDNS,
		EnableAfterCreation: p.EnabledAfterCreation,
		UserId:              adminId,
		Enabled:             p.Enabled,
	}

	// A non-empty id means: update the existing client.
	if p.Id != "" {
		keys, err := json.Marshal(p.Keys)
		if err != nil {
			return nil, err
		}
		ent.Keys = string(keys)
		// NOTE(review): "keys" and "server_id" are not in the Select list, so
		// an edit never overwrites the stored keys or server. ent.Keys only
		// populates the returned entity. Confirm that this is intended.
		err = r.Model(&entity.Client{}).
			Where("id = ?", p.Id).
			Select("name", "email", "subnet_range", "ip_allocation",
				"allowed_ips", "extra_allowed_ips", "endpoint", "use_server_dns",
				"enable_after_creation", "user_id", "enabled").
			Updates(ent).Error
		if err != nil {
			return nil, err
		}
		return ent, nil
	}

	// Reject the request if any requested address already belongs to a client.
	// ip_allocation stores a comma-joined list, so each address is matched as
	// a whole list element (the sole, first, last, or a middle entry). A plain
	// IN (...) would only ever match rows that hold a single address.
	// NOTE(review): this check-then-insert is not atomic. Concurrent creates
	// can still race unless the database enforces uniqueness.
	conds := make([]string, 0, len(p.IpAllocation))
	args := make([]interface{}, 0, 4*len(p.IpAllocation))
	for _, ip := range p.IpAllocation {
		if ip == "" {
			continue
		}
		conds = append(conds, "(ip_allocation = ? OR ip_allocation LIKE ? OR ip_allocation LIKE ? OR ip_allocation LIKE ?)")
		args = append(args, ip, ip+",%", "%,"+ip, "%,"+ip+",%")
	}
	if len(conds) > 0 {
		var count int64
		err := r.Model(&entity.Client{}).
			Where("("+strings.Join(conds, " OR ")+")", args...).
			Count(&count).Error
		if err != nil {
			log.Errorf("查询IP地址是否存在失败: %v", err.Error())
			return nil, err
		}
		if count > 0 {
			return nil, errors.New("该客户端的IP已经存在,请检查后再添加!")
		}
	}

	// New client: generate its WireGuard key pair and preshared key, and store
	// them as JSON in the keys column.
	privateKey, err := wgtypes.GeneratePrivateKey()
	if err != nil {
		return nil, err
	}
	presharedKey, err := wgtypes.GenerateKey()
	if err != nil {
		return nil, err
	}
	keysJSON, err := json.Marshal(template_data.Keys{
		PrivateKey:   privateKey.String(),
		PublicKey:    privateKey.PublicKey().String(),
		PresharedKey: presharedKey.String(),
	})
	if err != nil {
		return nil, err
	}
	ent.Keys = string(keysJSON)

	if err = r.Model(&entity.Client{}).Create(ent).Error; err != nil {
		return nil, err
	}
	return ent, nil
}

// Delete deletes the client whose id equals the given id, using gorm's Delete
// (soft or hard, depending on how entity.Client is declared).
// @description: delete a client
// @receiver r
// @param id: primary key of the client to delete
// @return err: the error reported by gorm, if any
func (r clientRepo) Delete(id string) (err error) {
	// r.Model and Where resolve to the embedded *gorm.DB. The Delete below is
	// called on the gorm query value, not on this method.
	query := r.Model(&entity.Client{}).Where("id = ?", id)
	err = query.Delete(&entity.Client{}).Error
	return err
}

// GetById loads a single client by primary key, eagerly loading its Server
// association.
// @description: fetch client details by id
// @receiver r
// @param id: primary key of the client
// @return data: the client row, with Server preloaded
// @return err: gorm's error, e.g. gorm.ErrRecordNotFound when no row matches
func (r clientRepo) GetById(id string) (data entity.Client, err error) {
	query := r.Model(&entity.Client{}).
		Preload("Server").
		Where("id = ?", id)
	err = query.First(&data).Error
	return data, err
}

// GetByPublicKey loads the client whose stored keys JSON has the given public
// key, eagerly loading its Server association.
// @description: fetch a client by its WireGuard public key
// @receiver r
// @param publicKey: the client's public key (base64 string)
// @return data: the matching client, with Server preloaded
// @return err: gorm's error, e.g. gorm.ErrRecordNotFound when no row matches
func (r clientRepo) GetByPublicKey(publicKey string) (data entity.Client, err error) {
	// `keys` is a reserved word in MySQL and must be quoted. The JSON path must
	// be a quoted string literal. ->> unquotes the extracted value so that it
	// compares as plain text (this works in MySQL >= 5.7.13 and SQLite >= 3.38).
	// NOTE(review): assumes template_data.Keys serialises PublicKey under the
	// JSON name "publicKey" — confirm against its struct tags.
	err = r.Model(&entity.Client{}).
		Where("`keys`->>'$.publicKey' = ?", publicKey).
		Preload("Server").
		First(&data).Error
	return data, err
}

// Disabled marks the client with the given id as disabled.
// @description: disable a client
// @receiver r
// @param id: primary key of the client
// @return err: the error reported by gorm, if any
func (r clientRepo) Disabled(id string) (err error) {
	// The enable flag is the "enabled" column: List filters on twc.enabled and
	// Save writes "enabled". Nothing in this file uses a "status" column.
	// NOTE(review): assumes entity.Client has no separate status column — confirm.
	return r.Model(&entity.Client{}).Where("id = ?", id).Update("enabled", false).Error
}