perf: 优化Eval.go的代码,添加注释,规范输出

This commit is contained in:
ZacharyZcR 2024-12-19 14:26:30 +08:00
parent 02eb3d6f7a
commit 4d3ccba255

View File

@ -24,68 +24,86 @@ import (
"time" "time"
) )
// NewEnv 创建一个新的 CEL 环境
func NewEnv(c *CustomLib) (*cel.Env, error) { func NewEnv(c *CustomLib) (*cel.Env, error) {
return cel.NewEnv(cel.Lib(c)) return cel.NewEnv(cel.Lib(c))
} }
// Evaluate 评估 CEL 表达式
func Evaluate(env *cel.Env, expression string, params map[string]interface{}) (ref.Val, error) { func Evaluate(env *cel.Env, expression string, params map[string]interface{}) (ref.Val, error) {
// 空表达式默认返回 true
if expression == "" { if expression == "" {
return types.Bool(true), nil return types.Bool(true), nil
} }
ast, iss := env.Compile(expression)
if iss.Err() != nil { // 编译表达式
//fmt.Printf("compile: ", iss.Err()) ast, issues := env.Compile(expression)
return nil, iss.Err() if issues.Err() != nil {
return nil, fmt.Errorf("表达式编译错误: %w", issues.Err())
} }
prg, err := env.Program(ast) // 创建程序
program, err := env.Program(ast)
if err != nil { if err != nil {
//fmt.Printf("Program creation error: %v", err) return nil, fmt.Errorf("程序创建错误: %w", err)
return nil, err
} }
out, _, err := prg.Eval(params) // 执行评估
result, _, err := program.Eval(params)
if err != nil { if err != nil {
//fmt.Printf("Evaluation error: %v", err) return nil, fmt.Errorf("表达式评估错误: %w", err)
return nil, err
} }
return out, nil
return result, nil
} }
// UrlTypeToString 将 URL 结构体转换为字符串
func UrlTypeToString(u *UrlType) string { func UrlTypeToString(u *UrlType) string {
var buf strings.Builder var builder strings.Builder
// 处理 scheme 部分
if u.Scheme != "" { if u.Scheme != "" {
buf.WriteString(u.Scheme) builder.WriteString(u.Scheme)
buf.WriteByte(':') builder.WriteByte(':')
} }
// 处理 host 部分
if u.Scheme != "" || u.Host != "" { if u.Scheme != "" || u.Host != "" {
if u.Host != "" || u.Path != "" { if u.Host != "" || u.Path != "" {
buf.WriteString("//") builder.WriteString("//")
} }
if h := u.Host; h != "" { if host := u.Host; host != "" {
buf.WriteString(u.Host) builder.WriteString(host)
} }
} }
// 处理 path 部分
path := u.Path path := u.Path
if path != "" && path[0] != '/' && u.Host != "" { if path != "" && path[0] != '/' && u.Host != "" {
buf.WriteByte('/') builder.WriteByte('/')
} }
if buf.Len() == 0 {
// 处理相对路径
if builder.Len() == 0 {
if i := strings.IndexByte(path, ':'); i > -1 && strings.IndexByte(path[:i], '/') == -1 { if i := strings.IndexByte(path, ':'); i > -1 && strings.IndexByte(path[:i], '/') == -1 {
buf.WriteString("./") builder.WriteString("./")
} }
} }
buf.WriteString(path) builder.WriteString(path)
// 处理查询参数
if u.Query != "" { if u.Query != "" {
buf.WriteByte('?') builder.WriteByte('?')
buf.WriteString(u.Query) builder.WriteString(u.Query)
} }
// 处理片段标识符
if u.Fragment != "" { if u.Fragment != "" {
buf.WriteByte('#') builder.WriteByte('#')
buf.WriteString(u.Fragment) builder.WriteString(u.Fragment)
} }
return buf.String()
return builder.String()
} }
type CustomLib struct { type CustomLib struct {
@ -519,179 +537,256 @@ func NewEnvOption() CustomLib {
return c return c
} }
// 声明环境中的变量类型和函数 // CompileOptions 返回环境编译选项
func (c *CustomLib) CompileOptions() []cel.EnvOption { func (c *CustomLib) CompileOptions() []cel.EnvOption {
return c.envOptions return c.envOptions
} }
// ProgramOptions 返回程序运行选项
func (c *CustomLib) ProgramOptions() []cel.ProgramOption { func (c *CustomLib) ProgramOptions() []cel.ProgramOption {
return c.programOptions return c.programOptions
} }
// UpdateCompileOptions 更新编译选项,处理不同类型的变量声明
func (c *CustomLib) UpdateCompileOptions(args StrMap) { func (c *CustomLib) UpdateCompileOptions(args StrMap) {
for _, item := range args { for _, item := range args {
k, v := item.Key, item.Value key, value := item.Key, item.Value
// 在执行之前是不知道变量的类型的,所以统一声明为字符型
// 所以randomInt虽然返回的是int型在运算中却被当作字符型进行计算需要重载string_*_string // 根据函数前缀确定变量类型
var d *exprpb.Decl var declaration *exprpb.Decl
if strings.HasPrefix(v, "randomInt") { switch {
d = decls.NewIdent(k, decls.Int, nil) case strings.HasPrefix(value, "randomInt"):
} else if strings.HasPrefix(v, "newReverse") { // randomInt 函数返回整型
d = decls.NewIdent(k, decls.NewObjectType("lib.Reverse"), nil) declaration = decls.NewIdent(key, decls.Int, nil)
} else { case strings.HasPrefix(value, "newReverse"):
d = decls.NewIdent(k, decls.String, nil) // newReverse 函数返回 Reverse 对象
declaration = decls.NewIdent(key, decls.NewObjectType("lib.Reverse"), nil)
default:
// 默认声明为字符串类型
declaration = decls.NewIdent(key, decls.String, nil)
} }
c.envOptions = append(c.envOptions, cel.Declarations(d))
c.envOptions = append(c.envOptions, cel.Declarations(declaration))
} }
} }
// 初始化随机数生成器
var randSource = rand.New(rand.NewSource(time.Now().Unix())) var randSource = rand.New(rand.NewSource(time.Now().Unix()))
// randomLowercase 生成指定长度的小写字母随机字符串
func randomLowercase(n int) string { func randomLowercase(n int) string {
lowercase := "abcdefghijklmnopqrstuvwxyz" const lowercase = "abcdefghijklmnopqrstuvwxyz"
return RandomStr(randSource, lowercase, n) return RandomStr(randSource, lowercase, n)
} }
// randomUppercase 生成指定长度的大写字母随机字符串
func randomUppercase(n int) string { func randomUppercase(n int) string {
uppercase := "ABCDEFGHIJKLMNOPQRSTUVWXYZ" const uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
return RandomStr(randSource, uppercase, n) return RandomStr(randSource, uppercase, n)
} }
// randomString 生成指定长度的随机字符串(包含大小写字母和数字)
func randomString(n int) string { func randomString(n int) string {
charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
return RandomStr(randSource, charset, n) return RandomStr(randSource, charset, n)
} }
// reverseCheck 检查 DNS 记录是否存在
func reverseCheck(r *Reverse, timeout int64) bool { func reverseCheck(r *Reverse, timeout int64) bool {
// 检查必要条件
if ceyeApi == "" || r.Domain == "" || !Common.DnsLog { if ceyeApi == "" || r.Domain == "" || !Common.DnsLog {
return false return false
} }
// 等待指定时间
time.Sleep(time.Second * time.Duration(timeout)) time.Sleep(time.Second * time.Duration(timeout))
// 提取子域名
sub := strings.Split(r.Domain, ".")[0] sub := strings.Split(r.Domain, ".")[0]
urlStr := fmt.Sprintf("http://api.ceye.io/v1/records?token=%s&type=dns&filter=%s", ceyeApi, sub)
//fmt.Println(urlStr) // 构造 API 请求 URL
req, _ := http.NewRequest("GET", urlStr, nil) apiURL := fmt.Sprintf("http://api.ceye.io/v1/records?token=%s&type=dns&filter=%s",
ceyeApi, sub)
// 创建并发送请求
req, _ := http.NewRequest("GET", apiURL, nil)
resp, err := DoRequest(req, false) resp, err := DoRequest(req, false)
if err != nil { if err != nil {
return false return false
} }
if !bytes.Contains(resp.Body, []byte(`"data": []`)) && bytes.Contains(resp.Body, []byte(`"message": "OK"`)) { // api返回结果不为空 // 检查响应内容
fmt.Println(urlStr) hasData := !bytes.Contains(resp.Body, []byte(`"data": []`))
isOK := bytes.Contains(resp.Body, []byte(`"message": "OK"`))
if hasData && isOK {
fmt.Println(apiURL)
return true return true
} }
return false return false
} }
// RandomStr 生成指定长度的随机字符串
func RandomStr(randSource *rand.Rand, letterBytes string, n int) string { func RandomStr(randSource *rand.Rand, letterBytes string, n int) string {
const ( const (
letterIdxBits = 6 // 6 bits to represent a letter index // 用 6 位比特表示一个字母索引
letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits letterIdxBits = 6
letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits // 生成掩码000111111
//letterBytes = "1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" letterIdxMask = 1<<letterIdxBits - 1
// 63 位能存储的字母索引数量
letterIdxMax = 63 / letterIdxBits
) )
// 预分配结果数组
randBytes := make([]byte, n) randBytes := make([]byte, n)
// 使用位操作生成随机字符串
for i, cache, remain := n-1, randSource.Int63(), letterIdxMax; i >= 0; { for i, cache, remain := n-1, randSource.Int63(), letterIdxMax; i >= 0; {
// 当可用的随机位用完时,重新获取随机数
if remain == 0 { if remain == 0 {
cache, remain = randSource.Int63(), letterIdxMax cache, remain = randSource.Int63(), letterIdxMax
} }
// 获取字符集中的随机索引
if idx := int(cache & letterIdxMask); idx < len(letterBytes) { if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
randBytes[i] = letterBytes[idx] randBytes[i] = letterBytes[idx]
i-- i--
} }
// 右移已使用的位,更新计数器
cache >>= letterIdxBits cache >>= letterIdxBits
remain-- remain--
} }
return string(randBytes) return string(randBytes)
} }
// DoRequest 执行 HTTP 请求
func DoRequest(req *http.Request, redirect bool) (*Response, error) { func DoRequest(req *http.Request, redirect bool) (*Response, error) {
if req.Body == nil || req.Body == http.NoBody { // 处理请求头
} else { if req.Body != nil && req.Body != http.NoBody {
// 设置 Content-Length
req.Header.Set("Content-Length", strconv.Itoa(int(req.ContentLength))) req.Header.Set("Content-Length", strconv.Itoa(int(req.ContentLength)))
// 如果未指定 Content-Type设置默认值
if req.Header.Get("Content-Type") == "" { if req.Header.Get("Content-Type") == "" {
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
} }
} }
var oResp *http.Response
var err error // 执行请求
var (
oResp *http.Response
err error
)
if redirect { if redirect {
oResp, err = Client.Do(req) oResp, err = Client.Do(req)
} else { } else {
oResp, err = ClientNoRedirect.Do(req) oResp, err = ClientNoRedirect.Do(req)
} }
if err != nil { if err != nil {
//fmt.Println("[-]DoRequest error: ",err) return nil, fmt.Errorf("请求执行失败: %w", err)
return nil, err
} }
defer oResp.Body.Close() defer oResp.Body.Close()
// 解析响应
resp, err := ParseResponse(oResp) resp, err := ParseResponse(oResp)
if err != nil { if err != nil {
Common.LogError("[-] ParseResponse error: " + err.Error()) Common.LogError("响应解析失败: " + err.Error())
//return nil, err
} }
return resp, err return resp, err
} }
// ParseUrl 解析 URL 并转换为自定义 URL 类型
func ParseUrl(u *url.URL) *UrlType { func ParseUrl(u *url.URL) *UrlType {
nu := &UrlType{} return &UrlType{
nu.Scheme = u.Scheme Scheme: u.Scheme,
nu.Domain = u.Hostname() Domain: u.Hostname(),
nu.Host = u.Host Host: u.Host,
nu.Port = u.Port() Port: u.Port(),
nu.Path = u.EscapedPath() Path: u.EscapedPath(),
nu.Query = u.RawQuery Query: u.RawQuery,
nu.Fragment = u.Fragment Fragment: u.Fragment,
return nu }
} }
// ParseRequest 将标准 HTTP 请求转换为自定义请求对象
func ParseRequest(oReq *http.Request) (*Request, error) { func ParseRequest(oReq *http.Request) (*Request, error) {
req := &Request{} req := &Request{
req.Method = oReq.Method Method: oReq.Method,
req.Url = ParseUrl(oReq.URL) Url: ParseUrl(oReq.URL),
header := make(map[string]string) Headers: make(map[string]string),
for k := range oReq.Header { ContentType: oReq.Header.Get("Content-Type"),
header[k] = oReq.Header.Get(k)
} }
req.Headers = header
req.ContentType = oReq.Header.Get("Content-Type") // 复制请求头
if oReq.Body == nil || oReq.Body == http.NoBody { for k := range oReq.Header {
} else { req.Headers[k] = oReq.Header.Get(k)
}
// 处理请求体
if oReq.Body != nil && oReq.Body != http.NoBody {
data, err := io.ReadAll(oReq.Body) data, err := io.ReadAll(oReq.Body)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("读取请求体失败: %w", err)
} }
req.Body = data req.Body = data
// 重新设置请求体,允许后续重复读取
oReq.Body = io.NopCloser(bytes.NewBuffer(data)) oReq.Body = io.NopCloser(bytes.NewBuffer(data))
} }
return req, nil return req, nil
} }
// ParseResponse 将标准 HTTP 响应转换为自定义响应对象
func ParseResponse(oResp *http.Response) (*Response, error) { func ParseResponse(oResp *http.Response) (*Response, error) {
var resp Response resp := Response{
header := make(map[string]string) Status: int32(oResp.StatusCode),
resp.Status = int32(oResp.StatusCode) Url: ParseUrl(oResp.Request.URL),
resp.Url = ParseUrl(oResp.Request.URL) Headers: make(map[string]string),
for k := range oResp.Header { ContentType: oResp.Header.Get("Content-Type"),
header[k] = strings.Join(oResp.Header.Values(k), ";") }
// 复制响应头,合并多值头部为分号分隔的字符串
for k := range oResp.Header {
resp.Headers[k] = strings.Join(oResp.Header.Values(k), ";")
}
// 读取并解析响应体
body, err := getRespBody(oResp)
if err != nil {
return nil, fmt.Errorf("处理响应体失败: %w", err)
} }
resp.Headers = header
resp.ContentType = oResp.Header.Get("Content-Type")
body, _ := getRespBody(oResp)
resp.Body = body resp.Body = body
return &resp, nil return &resp, nil
} }
func getRespBody(oResp *http.Response) (body []byte, err error) { // getRespBody 读取 HTTP 响应体并处理可能的 gzip 压缩
body, err = io.ReadAll(oResp.Body) func getRespBody(oResp *http.Response) ([]byte, error) {
// 读取原始响应体
body, err := io.ReadAll(oResp.Body)
if err != nil && err != io.EOF {
return nil, err
}
// 处理 gzip 压缩
if strings.Contains(oResp.Header.Get("Content-Encoding"), "gzip") { if strings.Contains(oResp.Header.Get("Content-Encoding"), "gzip") {
reader, err1 := gzip.NewReader(bytes.NewReader(body)) reader, err := gzip.NewReader(bytes.NewReader(body))
if err1 == nil { if err != nil {
body, err = io.ReadAll(reader) return body, nil // 如果解压失败,返回原始数据
} }
defer reader.Close()
decompressed, err := io.ReadAll(reader)
if err != nil && err != io.EOF {
return nil, err
}
return decompressed, nil
} }
if err == io.EOF {
err = nil return body, nil
}
return
} }