package service

import (
	"encoding/json"
	"errors"
	"fmt"
	"gorm.io/gorm"
	"strings"
	gdb "wireguard-ui/global/client"
	"wireguard-ui/http/param"
	"wireguard-ui/http/vo"
	"wireguard-ui/model"
	"wireguard-ui/utils"
)

// client is the service for WireGuard client records (the t_client table used
// by List). It embeds *gorm.DB so GORM query-builder methods such as Model,
// Table, Where and Create can be called directly on the receiver.
type client struct {
	*gorm.DB
}

// Client returns a client service bound to the shared database handle gdb.DB.
func Client() client {
	svc := client{DB: gdb.DB}
	return svc
}

// SaveClient creates a new client or updates an existing one.
//
// When p.Id is non-empty, the t_client row with that id is updated (only the
// columns listed in the Select call are written). Otherwise a new row is
// inserted, after checking that no existing client has exactly the same
// comma-joined IP-allocation string.
//
// Every requested client IP must fall inside the WireGuard server's address
// (as checked by utils.Network().IPContains). An empty p.Endpoint is filled in
// as "<global endpoint address>:<server listen port>". loginUser.Id is stored
// as the client's user_id on both create and update.
//
// It returns any error from loading settings, serialising p.Keys, validation,
// or the database.
func (s client) SaveClient(p param.SaveClient, loginUser *vo.User) error {
	serverConf, err := Setting().GetWGServerForConfig()
	if err != nil {
		return err
	}

	// Every client IP must lie within the server's configured address range.
	for _, cip := range p.IpAllocation {
		if !utils.Network().IPContains(serverConf.Address, cip) {
			return fmt.Errorf("客户端IP[%s]不符合定义", cip)
		}
	}

	// Default the endpoint to the global endpoint address plus the server's
	// listen port.
	// NOTE(review): "%s:%d" is ambiguous for IPv6 literals; net.JoinHostPort
	// would bracket them — confirm whether IPv6 endpoints are possible.
	if p.Endpoint == "" {
		globalConf, err := Setting().GetWGSetForConfig()
		if err != nil {
			return err
		}

		p.Endpoint = fmt.Sprintf("%s:%d", globalConf.EndpointAddress, serverConf.ListenPort)
	}

	// Do not persist an empty/garbled keys column if serialisation fails.
	keys, err := json.Marshal(p.Keys)
	if err != nil {
		return fmt.Errorf("客户端密钥序列化失败: %w", err)
	}

	ipAllocation := strings.Join(p.IpAllocation, ",")

	// NOTE(review): the three boolean fields are dereferenced without a nil
	// check; this assumes request binding guarantees they are always set.
	ent := &model.Client{
		Base: model.Base{
			Id: p.Id,
		},
		Name:              p.Name,
		Email:             p.Email,
		SubnetRange:       p.SubnetRange,
		IpAllocation:      ipAllocation,
		AllowedIps:        strings.Join(p.AllowedIps, ","),
		ExtraAllowedIps:   strings.Join(p.ExtraAllowedIps, ","),
		Endpoint:          p.Endpoint,
		UseServerDns:      *p.UseServerDns,
		Keys:              string(keys),
		UserId:            loginUser.Id,
		Enabled:           *p.Enabled,
		OfflineMonitoring: *p.OfflineMonitoring,
	}

	// Update: Select forces these columns to be written even when the new value
	// is a zero value (e.g. enabled=false).
	// NOTE(review): unlike create, the edit path does not check whether another
	// client already uses the same IP allocation — confirm this is intended.
	if p.Id != "" {
		return s.Model(&model.Client{}).Select("id", "name", "email", "subnet_range",
			"ip_allocation", "allowed_ips",
			"extra_allowed_ips", "endpoint",
			"use_server_dns", "keys", "user_id", "enabled",
			"offline_monitoring").
			Where("id = ?", ent.Id).Updates(ent).Error
	}

	// Create: reject a new client whose IP allocation string is already used.
	var count int64
	if err := s.Model(&model.Client{}).Where("ip_allocation = ?", ipAllocation).Count(&count).Error; err != nil {
		return err
	}

	if count > 0 {
		return errors.New("当前IP客户端已经存在")
	}

	return s.Create(ent).Error
}

// Delete removes the client whose primary key equals id and returns any
// database error.
// NOTE(review): whether this is a soft or hard delete depends on
// model.Client's definition, which is not visible here.
func (s client) Delete(id string) error {
	byID := s.Model(&model.Client{}).Where("id = ?", id)
	result := byID.Delete(&model.Client{})
	return result.Error
}

// List returns one page of clients together with the total number of rows that
// match the filters (pagination ignored for the total).
//
// Each row is joined with t_user to expose the creator's nickname as
// create_user. Optional filters:
//   - p.Enabled: exact match;
//   - p.Name, p.Email, p.IpAllocation: substring (LIKE) match.
//
// Results are ordered newest first. The comma-joined columns are expanded back
// into slices; AllowedIps and ExtraAllowedIps are always non-nil, while
// IpAllocation is left nil when the stored value is empty. On error, data is
// nil and total is 0.
func (s client) List(p param.ClientList) (data []vo.ClientItem, total int64, err error) {
	// tc.updated_at belongs in the same column list; passing it as a separate
	// Select argument is interpreted differently across GORM versions.
	sel := s.Table("t_client as tc").Scopes(Paginate(p.Current, p.Size)).
		Joins("LEFT JOIN t_user as tu ON tu.id = tc.user_id").
		Select("tc.id,tc.name,tc.email,tc.ip_allocation as ip_allocation_str," +
			"tc.allowed_ips as allowed_ips_str,tc.extra_allowed_ips as extra_allowed_ips_str," +
			"tc.endpoint,tc.use_server_dns,tc.keys as keys_str,tu.nickname as create_user," +
			"tc.enabled,tc.offline_monitoring," +
			"tc.created_at,tc.updated_at")

	// Always keep the builder returned by each chained call.
	if p.Enabled != nil {
		sel = sel.Where("tc.enabled = ?", *p.Enabled)
	}
	if p.Name != "" {
		sel = sel.Where("tc.name like ?", "%"+p.Name+"%")
	}
	if p.Email != "" {
		sel = sel.Where("tc.email like ?", "%"+p.Email+"%")
	}
	if p.IpAllocation != "" {
		sel = sel.Where("tc.ip_allocation like ?", "%"+p.IpAllocation+"%")
	}

	// Find fetches the current page; Limit(-1)/Offset(-1) then cancel the
	// pagination so Count reports the full number of matching rows.
	err = sel.Order("tc.created_at DESC").Find(&data).Limit(-1).Offset(-1).Count(&total).Error
	if err != nil {
		return nil, 0, err
	}

	splitOrEmpty := func(joined string) []string {
		if joined == "" {
			return []string{}
		}
		return strings.Split(joined, ",")
	}

	for i := range data {
		item := &data[i]
		if item.KeysStr != "" {
			// Best effort: a malformed keys column leaves Keys at its zero value
			// instead of failing the whole page.
			_ = json.Unmarshal([]byte(item.KeysStr), &item.Keys)
		}
		if item.IpAllocationStr != "" {
			item.IpAllocation = strings.Split(item.IpAllocationStr, ",")
		}
		item.AllowedIps = splitOrEmpty(item.AllowedIpsStr)
		item.ExtraAllowedIps = splitOrEmpty(item.ExtraAllowedIpsStr)
	}

	return data, total, nil
}

// GetByID fetches the single client whose primary key equals id.
//
// Any query error is returned as err (with GORM's Take, a missing row is
// reported as an error).
// NOTE(review): data is passed to GORM as a pointer-to-pointer, so it may be a
// non-nil zero-valued client even when err is set — check err first.
func (s client) GetByID(id string) (data *model.Client, err error) {
	query := s.Model(&model.Client{}).Where("id = ?", id)
	err = query.Take(&data).Error
	return data, err
}

// GetByPublicKey returns the first client whose stored keys JSON has pk at the
// "$.publicKey" path.
//
// pk is sent as a bound query parameter rather than being formatted into the
// SQL text, so a value containing quotes cannot change the query. Any query
// error (including no match, as reported by GORM's First) is returned as err.
func (s client) GetByPublicKey(pk string) (data *model.Client, err error) {
	err = s.Model(&model.Client{}).
		Where("json_extract(keys, '$.publicKey') = ?", pk).
		First(&data).Error
	return
}