perf: 日常优化

This commit is contained in:
ZacharyZcR 2025-05-05 04:00:35 +08:00
parent 2b4a4024b8
commit 0dc4a6c360
5 changed files with 809 additions and 786 deletions

View File

@ -1,262 +1,151 @@
package Core
import (
"context"
"fmt"
"github.com/shadow1ng/fscan/Common"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"net"
"sort"
"strings"
"sync"
"sync/atomic"
"time"
)
// Addr 表示待扫描的地址
type Addr struct {
ip string // IP地址
port int // 端口号
}
// ScanResult 扫描结果
type ScanResult struct {
Address string // IP地址
Port int // 端口号
Service *ServiceInfo // 服务信息
}
// PortScan 执行端口扫描
// hostslist: 待扫描的主机列表
// ports: 待扫描的端口范围
// timeout: 超时时间(秒)
// 返回活跃地址列表
func PortScan(hostslist []string, ports string, timeout int64) []string {
var results []ScanResult
var aliveAddrs []string
var mu sync.Mutex
// 解析并验证端口列表
probePorts := Common.ParsePort(ports)
if len(probePorts) == 0 {
Common.LogError(fmt.Sprintf("端口格式错误: %s", ports))
return aliveAddrs
// EnhancedPortScan 高性能端口扫描函数
func EnhancedPortScan(hosts []string, ports string, timeout int64) []string {
// 解析端口和排除端口
portList := Common.ParsePort(ports)
if len(portList) == 0 {
Common.LogError("无效端口: " + ports)
return nil
}
// 排除指定端口
probePorts = excludeNoPorts(probePorts)
exclude := make(map[int]struct{})
for _, p := range Common.ParsePort(Common.ExcludePorts) {
exclude[p] = struct{}{}
}
// 初始化并发控制
workers := Common.ThreadNum
addrs := make(chan Addr, 100) // 待扫描地址通道
scanResults := make(chan ScanResult, 100) // 扫描结果通道
var wg sync.WaitGroup
var workerWg sync.WaitGroup
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
to := time.Duration(timeout) * time.Second
sem := semaphore.NewWeighted(int64(Common.ThreadNum))
var count int64
var aliveMap sync.Map
g, ctx := errgroup.WithContext(ctx)
// 启动扫描工作协程
for i := 0; i < workers; i++ {
workerWg.Add(1)
go func() {
defer workerWg.Done()
for addr := range addrs {
PortConnect(addr, scanResults, timeout, &wg)
}
}()
// 并发扫描所有目标
for _, host := range hosts {
for _, port := range portList {
if _, excluded := exclude[port]; excluded {
continue
}
// 启动结果处理协程
var resultWg sync.WaitGroup
resultWg.Add(1)
go func() {
defer resultWg.Done()
for result := range scanResults {
mu.Lock()
results = append(results, result)
aliveAddr := fmt.Sprintf("%s:%d", result.Address, result.Port)
aliveAddrs = append(aliveAddrs, aliveAddr)
mu.Unlock()
}
}()
host, port := host, port // 捕获循环变量
addr := fmt.Sprintf("%s:%d", host, port)
// 分发扫描任务
for _, port := range probePorts {
for _, host := range hostslist {
wg.Add(1)
addrs <- Addr{host, port}
}
if err := sem.Acquire(ctx, 1); err != nil {
break
}
// 等待所有任务完成
close(addrs)
workerWg.Wait()
wg.Wait()
close(scanResults)
resultWg.Wait()
g.Go(func() error {
defer sem.Release(1)
return aliveAddrs
}
// PortConnect 执行单个端口连接检测
// addr: 待检测的地址
// results: 结果通道
// timeout: 超时时间
// wg: 等待组
func PortConnect(addr Addr, results chan<- ScanResult, timeout int64, wg *sync.WaitGroup) {
defer wg.Done()
var isOpen bool
var err error
var conn net.Conn
// 尝试建立TCP连接
conn, err = Common.WrapperTcpWithTimeout("tcp4",
fmt.Sprintf("%s:%v", addr.ip, addr.port),
time.Duration(timeout)*time.Second)
if err == nil {
// 连接测试
conn, err := net.DialTimeout("tcp", addr, to)
if err != nil {
return nil
}
defer conn.Close()
isOpen = true
}
if err != nil || !isOpen {
return
}
// 记录开放端口
address := fmt.Sprintf("%s:%d", addr.ip, addr.port)
Common.LogSuccess(fmt.Sprintf("端口开放 %s", address))
atomic.AddInt64(&count, 1)
aliveMap.Store(addr, struct{}{})
Common.LogSuccess("端口开放 " + addr)
Common.SaveResult(&Common.ScanResult{
Time: time.Now(), Type: Common.PORT, Target: host,
Status: "open", Details: map[string]interface{}{"port": port},
})
// 保存端口扫描结果
portResult := &Common.ScanResult{
Time: time.Now(),
Type: Common.PORT,
Target: addr.ip,
Status: "open",
Details: map[string]interface{}{
"port": addr.port,
},
}
Common.SaveResult(portResult)
// 构造扫描结果
result := ScanResult{
Address: addr.ip,
Port: addr.port,
// 服务识别
if Common.EnableFingerprint {
if info, err := NewPortInfoScanner(host, port, conn, to).Identify(); err == nil {
// 构建结果详情
details := map[string]interface{}{"port": port, "service": info.Name}
if info.Version != "" {
details["version"] = info.Version
}
// 执行服务识别
if Common.EnableFingerprint && conn != nil {
scanner := NewPortInfoScanner(addr.ip, addr.port, conn, time.Duration(timeout)*time.Second)
if serviceInfo, err := scanner.Identify(); err == nil {
result.Service = serviceInfo
// 构造服务识别日志
var logMsg strings.Builder
logMsg.WriteString(fmt.Sprintf("服务识别 %s => ", address))
if serviceInfo.Name != "unknown" {
logMsg.WriteString(fmt.Sprintf("[%s]", serviceInfo.Name))
// 处理额外信息
for k, v := range info.Extras {
if v == "" {
continue
}
if serviceInfo.Version != "" {
logMsg.WriteString(fmt.Sprintf(" 版本:%s", serviceInfo.Version))
}
// 收集服务详细信息
details := map[string]interface{}{
"port": addr.port,
"service": serviceInfo.Name,
}
// 添加版本信息
if serviceInfo.Version != "" {
details["version"] = serviceInfo.Version
}
// 添加产品信息
if v, ok := serviceInfo.Extras["vendor_product"]; ok && v != "" {
switch k {
case "vendor_product":
details["product"] = v
logMsg.WriteString(fmt.Sprintf(" 产品:%s", v))
case "os", "info":
details[k] = v
}
}
if len(info.Banner) > 0 {
details["banner"] = strings.TrimSpace(info.Banner)
}
// 添加操作系统信息
if v, ok := serviceInfo.Extras["os"]; ok && v != "" {
details["os"] = v
logMsg.WriteString(fmt.Sprintf(" 系统:%s", v))
// 保存服务结果
Common.SaveResult(&Common.ScanResult{
Time: time.Now(), Type: Common.SERVICE, Target: host,
Status: "identified", Details: details,
})
// 记录服务信息
var sb strings.Builder
sb.WriteString("服务识别 " + addr + " => ")
if info.Name != "unknown" {
sb.WriteString("[" + info.Name + "]")
}
if info.Version != "" {
sb.WriteString(" 版本:" + info.Version)
}
// 添加额外信息
if v, ok := serviceInfo.Extras["info"]; ok && v != "" {
details["info"] = v
logMsg.WriteString(fmt.Sprintf(" 信息:%s", v))
for k, v := range info.Extras {
if v == "" {
continue
}
// 添加Banner信息
if len(serviceInfo.Banner) > 0 && len(serviceInfo.Banner) < 100 {
details["banner"] = strings.TrimSpace(serviceInfo.Banner)
logMsg.WriteString(fmt.Sprintf(" Banner:[%s]", strings.TrimSpace(serviceInfo.Banner)))
}
// 保存服务识别结果
serviceResult := &Common.ScanResult{
Time: time.Now(),
Type: Common.SERVICE,
Target: addr.ip,
Status: "identified",
Details: details,
}
Common.SaveResult(serviceResult)
Common.LogSuccess(logMsg.String())
switch k {
case "vendor_product":
sb.WriteString(" 产品:" + v)
case "os":
sb.WriteString(" 系统:" + v)
case "info":
sb.WriteString(" 信息:" + v)
}
}
results <- result
}
// NoPortScan 生成端口列表(不进行扫描)
// hostslist: 主机列表
// ports: 端口范围
// 返回地址列表
func NoPortScan(hostslist []string, ports string) []string {
var AliveAddress []string
// 解析并排除端口
probePorts := excludeNoPorts(Common.ParsePort(ports))
// 生成地址列表
for _, port := range probePorts {
for _, host := range hostslist {
address := fmt.Sprintf("%s:%d", host, port)
AliveAddress = append(AliveAddress, address)
}
}
return AliveAddress
}
// excludeNoPorts 排除指定的端口
// ports: 原始端口列表
// 返回过滤后的端口列表
func excludeNoPorts(ports []int) []int {
noPorts := Common.ParsePort(Common.ExcludePorts)
if len(noPorts) == 0 {
return ports
}
// 使用map过滤端口
temp := make(map[int]struct{})
for _, port := range ports {
temp[port] = struct{}{}
}
// 移除需要排除的端口
for _, port := range noPorts {
delete(temp, port)
}
// 转换为有序切片
var newPorts []int
for port := range temp {
newPorts = append(newPorts, port)
}
sort.Ints(newPorts)
return newPorts
if len(info.Banner) > 0 && len(info.Banner) < 100 {
sb.WriteString(" Banner:[" + strings.TrimSpace(info.Banner) + "]")
}
Common.LogSuccess(sb.String())
}
}
return nil
})
}
}
_ = g.Wait()
// 收集结果
var aliveAddrs []string
aliveMap.Range(func(key, _ interface{}) bool {
aliveAddrs = append(aliveAddrs, key.(string))
return true
})
Common.LogSuccess(fmt.Sprintf("扫描完成, 发现 %d 个开放端口", count))
return aliveAddrs
}

View File

@ -92,7 +92,7 @@ func (s *ServiceScanStrategy) discoverAlivePorts(hosts []string) []string {
// 根据扫描模式选择端口扫描方式
if len(hosts) > 0 {
alivePorts = PortScan(hosts, Common.Ports, Common.Timeout)
alivePorts = EnhancedPortScan(hosts, Common.Ports, Common.Timeout)
Common.LogInfo(fmt.Sprintf("存活端口数量: %d", len(alivePorts)))
}

View File

@ -8,9 +8,9 @@ import (
"io"
"net"
"net/http"
"net/url"
"regexp"
"strings"
"sync"
"time"
"unicode/utf8"
@ -20,175 +20,203 @@ import (
"golang.org/x/text/encoding/simplifiedchinese"
)
// 常量定义
const (
maxTitleLength = 100
defaultProtocol = "http"
httpsProtocol = "https"
httpProtocol = "http"
printerFingerPrint = "打印机"
emptyTitle = "\"\""
noTitleText = "无标题"
// HTTP相关常量
httpPort = "80"
httpsPort = "443"
contentEncoding = "Content-Encoding"
gzipEncoding = "gzip"
contentLength = "Content-Length"
)
// 错误定义
var (
ErrNoTitle = fmt.Errorf("无法获取标题")
ErrHTTPClientInit = fmt.Errorf("HTTP客户端未初始化")
ErrReadRespBody = fmt.Errorf("读取响应内容失败")
)
// 响应结果
type WebResponse struct {
Url string
StatusCode int
Title string
Length string
Headers map[string]string
RedirectUrl string
Body []byte
Error error
}
// 协议检测结果
type ProtocolResult struct {
Protocol string
Success bool
}
// WebTitle 获取Web标题和指纹信息
func WebTitle(info *Common.HostInfo) error {
Common.LogDebug(fmt.Sprintf("开始获取Web标题初始信息: %+v", info))
if info == nil {
return fmt.Errorf("主机信息为空")
}
// 初始化Url
if err := initializeUrl(info); err != nil {
Common.LogError(fmt.Sprintf("初始化Url失败: %v", err))
return err
}
// 获取网站标题信息
err, CheckData := GOWebTitle(info)
Common.LogDebug(fmt.Sprintf("GOWebTitle执行完成 - 错误: %v, 检查数据长度: %d", err, len(CheckData)))
checkData, err := fetchWebInfo(info)
if err != nil {
// 记录错误但继续处理可能获取的数据
Common.LogError(fmt.Sprintf("获取网站信息失败: %s %v", info.Url, err))
}
info.Infostr = WebScan.InfoCheck(info.Url, &CheckData)
Common.LogDebug(fmt.Sprintf("信息检查完成,获得信息: %v", info.Infostr))
// 分析指纹
if len(checkData) > 0 {
info.Infostr = WebScan.InfoCheck(info.Url, &checkData)
// 检查是否为打印机,避免意外打印
for _, v := range info.Infostr {
if v == "打印机" {
Common.LogDebug("检测到打印机,停止扫描")
if v == printerFingerPrint {
Common.LogInfo("检测到打印机,停止扫描")
return nil
}
}
// 输出错误信息(如果有)
if err != nil {
errlog := fmt.Sprintf("网站标题 %v %v", info.Url, err)
Common.LogError(errlog)
}
return err
}
// GOWebTitle 获取网站标题并处理URL增强错误处理和协议切换
func GOWebTitle(info *Common.HostInfo) (err error, CheckData []WebScan.CheckDatas) {
Common.LogDebug(fmt.Sprintf("开始处理URL: %s", info.Url))
// 如果URL未指定根据端口生成URL
// 初始化Url根据主机和端口生成完整Url
func initializeUrl(info *Common.HostInfo) error {
if info.Url == "" {
Common.LogDebug("URL为空根据端口生成URL")
// 根据端口推断Url
switch info.Ports {
case "80":
info.Url = fmt.Sprintf("http://%s", info.Host)
case "443":
info.Url = fmt.Sprintf("https://%s", info.Host)
case httpPort:
info.Url = fmt.Sprintf("%s://%s", httpProtocol, info.Host)
case httpsPort:
info.Url = fmt.Sprintf("%s://%s", httpsProtocol, info.Host)
default:
host := fmt.Sprintf("%s:%s", info.Host, info.Ports)
Common.LogDebug(fmt.Sprintf("正在检测主机协议: %s", host))
protocol := GetProtocol(host, Common.Timeout)
Common.LogDebug(fmt.Sprintf("检测到协议: %s", protocol))
protocol, err := detectProtocol(host, Common.Timeout)
if err != nil {
return fmt.Errorf("协议检测失败: %w", err)
}
info.Url = fmt.Sprintf("%s://%s:%s", protocol, info.Host, info.Ports)
}
} else {
// 处理未指定协议的URL
if !strings.Contains(info.Url, "://") {
Common.LogDebug("URL未包含协议开始检测")
} else if !strings.Contains(info.Url, "://") {
// 处理未指定协议的Url
host := strings.Split(info.Url, "/")[0]
protocol := GetProtocol(host, Common.Timeout)
Common.LogDebug(fmt.Sprintf("检测到协议: %s", protocol))
protocol, err := detectProtocol(host, Common.Timeout)
if err != nil {
return fmt.Errorf("协议检测失败: %w", err)
}
info.Url = fmt.Sprintf("%s://%s", protocol, info.Url)
}
}
Common.LogDebug(fmt.Sprintf("协议检测完成后的URL: %s", info.Url))
// 记录原始URL协议
originalProtocol := "http"
if strings.HasPrefix(info.Url, "https://") {
originalProtocol = "https"
}
// 第一次获取URL
Common.LogDebug("第一次尝试访问URL")
err, result, CheckData := geturl(info, 1, CheckData)
Common.LogDebug(fmt.Sprintf("第一次访问结果 - 错误: %v, 返回信息: %s", err, result))
// 如果访问失败并且使用的是HTTPS尝试降级到HTTP
if err != nil && !strings.Contains(err.Error(), "EOF") {
if originalProtocol == "https" {
Common.LogDebug("HTTPS访问失败尝试降级到HTTP")
// 替换协议部分
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
Common.LogDebug(fmt.Sprintf("降级后的URL: %s", info.Url))
err, result, CheckData = geturl(info, 1, CheckData)
Common.LogDebug(fmt.Sprintf("HTTP降级访问结果 - 错误: %v, 返回信息: %s", err, result))
// 如果仍然失败,返回错误
if err != nil && !strings.Contains(err.Error(), "EOF") {
return
}
} else {
// 如果本来就是HTTP并且失败了直接返回错误
return
}
}
// 处理URL跳转
if strings.Contains(result, "://") {
Common.LogDebug(fmt.Sprintf("检测到重定向到: %s", result))
info.Url = result
err, result, CheckData = geturl(info, 3, CheckData)
Common.LogDebug(fmt.Sprintf("重定向请求结果 - 错误: %v, 返回信息: %s", err, result))
if err != nil {
// 如果重定向跟踪失败,尝试降级协议
if strings.HasPrefix(info.Url, "https://") {
Common.LogDebug("重定向HTTPS访问失败尝试降级到HTTP")
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
err, result, CheckData = geturl(info, 3, CheckData)
Common.LogDebug(fmt.Sprintf("重定向降级访问结果 - 错误: %v, 返回信息: %s", err, result))
}
if err != nil {
return
}
}
}
// 处理HTTP到HTTPS的升级提示
if result == "https" && !strings.HasPrefix(info.Url, "https://") {
Common.LogDebug("正在升级到HTTPS")
info.Url = strings.Replace(info.Url, "http://", "https://", 1)
Common.LogDebug(fmt.Sprintf("升级后的URL: %s", info.Url))
err, result, CheckData = geturl(info, 1, CheckData)
Common.LogDebug(fmt.Sprintf("HTTPS升级访问结果 - 错误: %v, 返回信息: %s", err, result))
// 如果HTTPS升级后访问失败回退到HTTP
if err != nil && !strings.Contains(err.Error(), "EOF") {
Common.LogDebug("HTTPS升级访问失败回退到HTTP")
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
err, result, CheckData = geturl(info, 1, CheckData)
Common.LogDebug(fmt.Sprintf("回退到HTTP访问结果 - 错误: %v, 返回信息: %s", err, result))
}
// 处理升级后的跳转
if strings.Contains(result, "://") {
Common.LogDebug(fmt.Sprintf("协议升级后发现重定向到: %s", result))
info.Url = result
err, _, CheckData = geturl(info, 3, CheckData)
if err != nil {
// 如果重定向跟踪失败,再次尝试降级
if strings.HasPrefix(info.Url, "https://") {
Common.LogDebug("升级后重定向HTTPS访问失败尝试降级到HTTP")
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
err, _, CheckData = geturl(info, 3, CheckData)
}
}
}
}
Common.LogDebug(fmt.Sprintf("GOWebTitle执行完成 - 错误: %v", err))
return
return nil
}
func geturl(info *Common.HostInfo, flag int, CheckData []WebScan.CheckDatas) (error, string, []WebScan.CheckDatas) {
Common.LogDebug(fmt.Sprintf("geturl开始执行 - URL: %s, 标志位: %d", info.Url, flag))
// 获取Web信息标题、指纹等
func fetchWebInfo(info *Common.HostInfo) ([]WebScan.CheckDatas, error) {
var checkData []WebScan.CheckDatas
// 处理目标URL
Url := info.Url
if flag == 2 {
Common.LogDebug("处理favicon.ico URL")
URL, err := url.Parse(Url)
if err == nil {
Url = fmt.Sprintf("%s://%s/favicon.ico", URL.Scheme, URL.Host)
} else {
Url += "/favicon.ico"
}
Common.LogDebug(fmt.Sprintf("favicon URL: %s", Url))
}
// 记录原始Url协议
originalUrl := info.Url
isHTTPS := strings.HasPrefix(info.Url, "https://")
// 创建HTTP请求
Common.LogDebug("开始创建HTTP请求")
req, err := http.NewRequest("GET", Url, nil)
// 第一次尝试访问Url
resp, err := fetchUrlWithRetry(info, false, &checkData)
// 处理不同的错误情况
if err != nil {
Common.LogDebug(fmt.Sprintf("创建HTTP请求失败: %v", err))
return err, "", CheckData
// 如果是HTTPS并失败尝试降级到HTTP
if isHTTPS {
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
resp, err = fetchUrlWithRetry(info, false, &checkData)
// 如果HTTP也失败恢复原始Url并返回错误
if err != nil {
info.Url = originalUrl
return checkData, err
}
} else {
return checkData, err
}
}
// 处理重定向
if resp != nil && resp.RedirectUrl != "" {
info.Url = resp.RedirectUrl
resp, err = fetchUrlWithRetry(info, true, &checkData)
// 如果重定向后失败,尝试降级协议
if err != nil && strings.HasPrefix(info.Url, "https://") {
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
resp, err = fetchUrlWithRetry(info, true, &checkData)
}
}
// 处理需要升级到HTTPS的情况
if resp != nil && resp.StatusCode == 400 && !strings.HasPrefix(info.Url, "https://") {
info.Url = strings.Replace(info.Url, "http://", "https://", 1)
resp, err = fetchUrlWithRetry(info, false, &checkData)
// 如果HTTPS升级失败回退到HTTP
if err != nil {
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
resp, err = fetchUrlWithRetry(info, false, &checkData)
}
// 处理升级后的重定向
if resp != nil && resp.RedirectUrl != "" {
info.Url = resp.RedirectUrl
resp, err = fetchUrlWithRetry(info, true, &checkData)
}
}
return checkData, err
}
// 尝试获取Url支持重试
func fetchUrlWithRetry(info *Common.HostInfo, followRedirect bool, checkData *[]WebScan.CheckDatas) (*WebResponse, error) {
// 获取页面内容
resp, err := fetchUrl(info.Url, followRedirect)
if err != nil {
return nil, err
}
// 保存检查数据
if resp.Body != nil && len(resp.Body) > 0 {
headers := fmt.Sprintf("%v", resp.Headers)
*checkData = append(*checkData, WebScan.CheckDatas{resp.Body, headers})
}
// 保存扫描结果
if resp.StatusCode > 0 {
saveWebResult(info, resp)
}
return resp, nil
}
// 抓取Url内容
func fetchUrl(targetUrl string, followRedirect bool) (*WebResponse, error) {
// 创建HTTP请求
req, err := http.NewRequest("GET", targetUrl, nil)
if err != nil {
return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
}
// 设置请求头
@ -199,90 +227,152 @@ func geturl(info *Common.HostInfo, flag int, CheckData []WebScan.CheckDatas) (er
req.Header.Set("Cookie", Common.Cookie)
}
req.Header.Set("Connection", "close")
Common.LogDebug("已设置请求头")
// 选择HTTP客户端
var client *http.Client
if flag == 1 {
client = lib.ClientNoRedirect
Common.LogDebug("使用不跟随重定向的客户端")
} else {
if followRedirect {
client = lib.Client
Common.LogDebug("使用普通客户端")
} else {
client = lib.ClientNoRedirect
}
// 检查客户端是否为空
if client == nil {
Common.LogDebug("错误: HTTP客户端为空")
return fmt.Errorf("HTTP客户端未初始化"), "", CheckData
return nil, ErrHTTPClientInit
}
// 发送请求
Common.LogDebug("开始发送HTTP请求")
resp, err := client.Do(req)
if err != nil {
Common.LogDebug(fmt.Sprintf("HTTP请求失败: %v", err))
return err, "https", CheckData
// 特殊处理SSL/TLS相关错误
errMsg := strings.ToLower(err.Error())
if strings.Contains(errMsg, "tls") || strings.Contains(errMsg, "ssl") ||
strings.Contains(errMsg, "handshake") || strings.Contains(errMsg, "certificate") {
return &WebResponse{Error: err}, nil
}
return nil, err
}
defer resp.Body.Close()
Common.LogDebug(fmt.Sprintf("收到HTTP响应状态码: %d", resp.StatusCode))
// 准备响应结果
result := &WebResponse{
Url: req.URL.String(),
StatusCode: resp.StatusCode,
Headers: make(map[string]string),
}
// 提取响应头
for k, v := range resp.Header {
if len(v) > 0 {
result.Headers[k] = v[0]
}
}
// 获取内容长度
result.Length = resp.Header.Get(contentLength)
// 检查重定向
redirectUrl, err := resp.Location()
if err == nil {
result.RedirectUrl = redirectUrl.String()
}
// 读取响应内容
body, err := getRespBody(resp)
body, err := readResponseBody(resp)
if err != nil {
Common.LogDebug(fmt.Sprintf("读取响应内容失败: %v", err))
return err, "https", CheckData
return result, fmt.Errorf("读取响应内容失败: %w", err)
}
Common.LogDebug(fmt.Sprintf("成功读取响应内容,长度: %d", len(body)))
result.Body = body
// 保存检查数据
CheckData = append(CheckData, WebScan.CheckDatas{body, fmt.Sprintf("%s", resp.Header)})
Common.LogDebug("已保存检查数据")
// 处理非favicon请求
var reurl string
if flag != 2 {
// 处理编码
// 提取标题
if !utf8.Valid(body) {
body, _ = simplifiedchinese.GBK.NewDecoder().Bytes(body)
}
result.Title = extractTitle(body)
// 获取页面信息
title := gettitle(body)
length := resp.Header.Get("Content-Length")
if length == "" {
length = fmt.Sprintf("%v", len(body))
if result.Length == "" {
result.Length = fmt.Sprintf("%d", len(body))
}
// 收集服务器信息
serverInfo := make(map[string]interface{})
serverInfo["title"] = title
serverInfo["length"] = length
serverInfo["status_code"] = resp.StatusCode
return result, nil
}
// 收集响应头信息
for k, v := range resp.Header {
if len(v) > 0 {
serverInfo[strings.ToLower(k)] = v[0]
// 读取HTTP响应体内容
func readResponseBody(resp *http.Response) ([]byte, error) {
var body []byte
var reader io.Reader = resp.Body
// 处理gzip压缩的响应
if resp.Header.Get(contentEncoding) == gzipEncoding {
gr, err := gzip.NewReader(resp.Body)
if err != nil {
return nil, fmt.Errorf("创建gzip解压器失败: %w", err)
}
defer gr.Close()
reader = gr
}
// 检查重定向
redirURL, err1 := resp.Location()
if err1 == nil {
reurl = redirURL.String()
serverInfo["redirect_url"] = reurl
// 读取内容
body, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("读取响应内容失败: %w", err)
}
// 处理指纹信息 - 添加调试日志
Common.LogDebug(fmt.Sprintf("保存结果前的指纹信息: %v", info.Infostr))
return body, nil
}
// 处理空指纹情况
// 提取网页标题
func extractTitle(body []byte) string {
// 使用正则表达式匹配title标签内容
re := regexp.MustCompile("(?ims)<title.*?>(.*?)</title>")
find := re.FindSubmatch(body)
if len(find) > 1 {
title := string(find[1])
// 清理标题内容
title = strings.TrimSpace(title)
title = strings.Replace(title, "\n", "", -1)
title = strings.Replace(title, "\r", "", -1)
title = strings.Replace(title, "&nbsp;", " ", -1)
// 截断过长的标题
if len(title) > maxTitleLength {
title = title[:maxTitleLength]
}
// 处理空标题
if title == "" {
return emptyTitle
}
return title
}
return noTitleText
}
// 保存Web扫描结果
func saveWebResult(info *Common.HostInfo, resp *WebResponse) {
// 处理指纹信息
fingerprints := info.Infostr
if len(fingerprints) == 1 && fingerprints[0] == "" {
// 如果是只包含空字符串的数组,替换为空数组
fingerprints = []string{}
Common.LogDebug("检测到空指纹,已转换为空数组")
}
// 准备服务器信息
serverInfo := make(map[string]interface{})
serverInfo["title"] = resp.Title
serverInfo["length"] = resp.Length
serverInfo["status_code"] = resp.StatusCode
// 添加响应头信息
for k, v := range resp.Headers {
serverInfo[strings.ToLower(k)] = v
}
// 添加重定向信息
if resp.RedirectUrl != "" {
serverInfo["redirect_Url"] = resp.RedirectUrl
}
// 保存扫描结果
@ -294,283 +384,170 @@ func geturl(info *Common.HostInfo, flag int, CheckData []WebScan.CheckDatas) (er
Details: map[string]interface{}{
"port": info.Ports,
"service": "http",
"title": title,
"url": resp.Request.URL.String(),
"title": resp.Title,
"Url": resp.Url,
"status_code": resp.StatusCode,
"length": length,
"length": resp.Length,
"server_info": serverInfo,
"fingerprints": fingerprints, // 使用处理过的指纹信息
"fingerprints": fingerprints,
},
}
Common.SaveResult(result)
Common.LogDebug(fmt.Sprintf("已保存结果,指纹信息: %v", fingerprints))
// 输出控制台日志
logMsg := fmt.Sprintf("网站标题 %-25v 状态码:%-3v 长度:%-6v 标题:%v",
resp.Request.URL, resp.StatusCode, length, title)
if reurl != "" {
logMsg += fmt.Sprintf(" 重定向地址: %s", reurl)
resp.Url, resp.StatusCode, resp.Length, resp.Title)
if resp.RedirectUrl != "" {
logMsg += fmt.Sprintf(" 重定向地址: %s", resp.RedirectUrl)
}
// 添加指纹信息到控制台日志
if len(fingerprints) > 0 {
logMsg += fmt.Sprintf(" 指纹:%v", fingerprints)
}
Common.LogSuccess(logMsg)
}
// 返回结果
if reurl != "" {
Common.LogDebug(fmt.Sprintf("返回重定向URL: %s", reurl))
return nil, reurl, CheckData
}
if resp.StatusCode == 400 && !strings.HasPrefix(info.Url, "https") {
Common.LogDebug("返回HTTPS升级标志")
return nil, "https", CheckData
}
Common.LogDebug("geturl执行完成无特殊返回")
return nil, "", CheckData
}
// getRespBody 读取HTTP响应体内容
func getRespBody(oResp *http.Response) ([]byte, error) {
Common.LogDebug("开始读取响应体内容")
var body []byte
// 处理gzip压缩的响应
if oResp.Header.Get("Content-Encoding") == "gzip" {
Common.LogDebug("检测到gzip压缩开始解压")
gr, err := gzip.NewReader(oResp.Body)
if err != nil {
Common.LogDebug(fmt.Sprintf("创建gzip解压器失败: %v", err))
return nil, err
}
defer gr.Close()
// 循环读取解压内容
for {
buf := make([]byte, 1024)
n, err := gr.Read(buf)
if err != nil && err != io.EOF {
Common.LogDebug(fmt.Sprintf("读取压缩内容失败: %v", err))
return nil, err
}
if n == 0 {
break
}
body = append(body, buf...)
}
Common.LogDebug(fmt.Sprintf("gzip解压完成内容长度: %d", len(body)))
} else {
// 直接读取未压缩的响应
Common.LogDebug("读取未压缩的响应内容")
raw, err := io.ReadAll(oResp.Body)
if err != nil {
Common.LogDebug(fmt.Sprintf("读取响应内容失败: %v", err))
return nil, err
}
body = raw
Common.LogDebug(fmt.Sprintf("读取完成,内容长度: %d", len(body)))
}
return body, nil
}
// gettitle 从HTML内容中提取网页标题
func gettitle(body []byte) (title string) {
Common.LogDebug("开始提取网页标题")
// 使用正则表达式匹配title标签内容
re := regexp.MustCompile("(?ims)<title.*?>(.*?)</title>")
find := re.FindSubmatch(body)
if len(find) > 1 {
title = string(find[1])
Common.LogDebug(fmt.Sprintf("找到原始标题: %s", title))
// 清理标题内容
title = strings.TrimSpace(title) // 去除首尾空格
title = strings.Replace(title, "\n", "", -1) // 去除换行
title = strings.Replace(title, "\r", "", -1) // 去除回车
title = strings.Replace(title, "&nbsp;", " ", -1) // 替换HTML空格
// 截断过长的标题
if len(title) > 100 {
Common.LogDebug("标题超过100字符进行截断")
title = title[:100]
// 检测目标主机的协议类型(HTTP/HTTPS)
func detectProtocol(host string, timeout int64) (string, error) {
// 根据标准端口快速判断协议
if strings.HasSuffix(host, ":"+httpPort) {
return httpProtocol, nil
} else if strings.HasSuffix(host, ":"+httpsPort) {
return httpsProtocol, nil
}
// 处理空标题
if title == "" {
Common.LogDebug("标题为空,使用双引号代替")
title = "\"\""
}
} else {
Common.LogDebug("未找到标题标签")
title = "无标题"
}
Common.LogDebug(fmt.Sprintf("最终标题: %s", title))
return
}
// GetProtocol 检测目标主机的协议类型(HTTP/HTTPS),优先返回可用的协议
func GetProtocol(host string, Timeout int64) (protocol string) {
Common.LogDebug(fmt.Sprintf("开始检测主机协议 - 主机: %s, 超时: %d秒", host, Timeout))
// 默认使用http协议
protocol = "http"
timeoutDuration := time.Duration(Timeout) * time.Second
timeoutDuration := time.Duration(timeout) * time.Second
ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration)
defer cancel()
// 1. 根据标准端口快速判断协议
if strings.HasSuffix(host, ":80") {
Common.LogDebug("检测到标准HTTP端口使用HTTP协议")
return "http"
} else if strings.HasSuffix(host, ":443") {
Common.LogDebug("检测到标准HTTPS端口使用HTTPS协议")
return "https"
}
// 2. 并发检测HTTP和HTTPS
type protocolResult struct {
name string
success bool
}
resultChan := make(chan protocolResult, 2)
singleTimeout := timeoutDuration / 2 // 每个协议检测的超时时间减半
// 并发检测HTTP和HTTPS
resultChan := make(chan ProtocolResult, 2)
wg := sync.WaitGroup{}
wg.Add(2)
// 检测HTTPS
go func() {
Common.LogDebug("开始检测HTTPS协议")
defer wg.Done()
success := checkHTTPS(host, timeoutDuration/2)
select {
case resultChan <- ProtocolResult{httpsProtocol, success}:
case <-ctx.Done():
}
}()
// 检测HTTP
go func() {
defer wg.Done()
success := checkHTTP(ctx, host, timeoutDuration/2)
select {
case resultChan <- ProtocolResult{httpProtocol, success}:
case <-ctx.Done():
}
}()
// 确保所有goroutine正常退出
go func() {
wg.Wait()
close(resultChan)
}()
// 收集结果
var httpsResult, httpResult *ProtocolResult
for result := range resultChan {
if result.Protocol == httpsProtocol {
r := result
httpsResult = &r
} else if result.Protocol == httpProtocol {
r := result
httpResult = &r
}
}
// 决定使用哪种协议 - 优先使用HTTPS
if httpsResult != nil && httpsResult.Success {
return httpsProtocol, nil
} else if httpResult != nil && httpResult.Success {
return httpProtocol, nil
}
// 默认使用HTTP
return defaultProtocol, nil
}
// 检测HTTPS协议
func checkHTTPS(host string, timeout time.Duration) bool {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS10,
}
dialer := &net.Dialer{
Timeout: singleTimeout,
Timeout: timeout,
}
conn, err := tls.DialWithDialer(dialer, "tcp", host, tlsConfig)
if err == nil {
Common.LogDebug("HTTPS连接成功")
conn.Close()
resultChan <- protocolResult{"https", true}
return
return true
}
// 分析TLS错误
if err != nil {
// 分析TLS错误某些错误可能表明服务器支持TLS但有其他问题
errMsg := strings.ToLower(err.Error())
// 这些错误可能表明服务器确实支持TLS但有其他问题
if strings.Contains(errMsg, "handshake failure") ||
return strings.Contains(errMsg, "handshake failure") ||
strings.Contains(errMsg, "certificate") ||
strings.Contains(errMsg, "tls") ||
strings.Contains(errMsg, "x509") ||
strings.Contains(errMsg, "secure") {
Common.LogDebug(fmt.Sprintf("TLS握手有错误但可能是HTTPS协议: %v", err))
resultChan <- protocolResult{"https", true}
return
}
Common.LogDebug(fmt.Sprintf("HTTPS连接失败: %v", err))
}
resultChan <- protocolResult{"https", false}
}()
strings.Contains(errMsg, "secure")
}
// 检测HTTP
go func() {
Common.LogDebug("开始检测HTTP协议")
// 检测HTTP协议
func checkHTTP(ctx context.Context, host string, timeout time.Duration) bool {
req, err := http.NewRequestWithContext(ctx, "HEAD", fmt.Sprintf("http://%s", host), nil)
if err != nil {
Common.LogDebug(fmt.Sprintf("创建HTTP请求失败: %v", err))
resultChan <- protocolResult{"http", false}
return
return false
}
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
DialContext: (&net.Dialer{
Timeout: singleTimeout,
Timeout: timeout,
}).DialContext,
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse // 不跟随重定向
},
Timeout: singleTimeout,
Timeout: timeout,
}
resp, err := client.Do(req)
if err == nil {
resp.Body.Close()
Common.LogDebug(fmt.Sprintf("HTTP连接成功状态码: %d", resp.StatusCode))
resultChan <- protocolResult{"http", true}
return
return true
}
Common.LogDebug(fmt.Sprintf("标准HTTP请求失败: %v尝试原始TCP连接", err))
// 尝试原始TCP连接和简单HTTP请求
netConn, err := net.DialTimeout("tcp", host, singleTimeout)
netConn, err := net.DialTimeout("tcp", host, timeout)
if err == nil {
defer netConn.Close()
netConn.SetDeadline(time.Now().Add(singleTimeout))
netConn.SetDeadline(time.Now().Add(timeout))
// 发送简单HTTP请求
_, err = netConn.Write([]byte("HEAD / HTTP/1.0\r\nHost: " + host + "\r\n\r\n"))
if err == nil {
// 读取响应
buf := make([]byte, 1024)
netConn.SetDeadline(time.Now().Add(singleTimeout))
netConn.SetDeadline(time.Now().Add(timeout))
n, err := netConn.Read(buf)
if err == nil && n > 0 {
response := string(buf[:n])
if strings.Contains(response, "HTTP/") {
Common.LogDebug("通过原始TCP连接确认HTTP协议")
resultChan <- protocolResult{"http", true}
return
return strings.Contains(response, "HTTP/")
}
}
}
Common.LogDebug("原始TCP连接成功但HTTP响应无效")
} else {
Common.LogDebug(fmt.Sprintf("原始TCP连接失败: %v", err))
}
resultChan <- protocolResult{"http", false}
}()
// 3. 收集结果并决定使用哪种协议
var httpsSuccess, httpSuccess bool
// 等待两个goroutine返回结果或超时
for i := 0; i < 2; i++ {
select {
case result := <-resultChan:
if result.name == "https" {
httpsSuccess = result.success
Common.LogDebug(fmt.Sprintf("HTTPS检测结果: %v", httpsSuccess))
} else if result.name == "http" {
httpSuccess = result.success
Common.LogDebug(fmt.Sprintf("HTTP检测结果: %v", httpSuccess))
}
case <-ctx.Done():
Common.LogDebug("协议检测超时")
break
}
}
// 4. 决定使用哪种协议 - 优先使用HTTPS如果HTTPS不可用则使用HTTP
if httpsSuccess {
Common.LogDebug("选择使用HTTPS协议")
return "https"
} else if httpSuccess {
Common.LogDebug("选择使用HTTP协议")
return "http"
}
// 5. 如果两种协议都无法确认,保持默认值
Common.LogDebug(fmt.Sprintf("无法确定协议,使用默认协议: %s", protocol))
return
return false
}

View File

@ -1,90 +1,171 @@
package WebScan
import (
"context"
"embed"
"errors"
"fmt"
"github.com/shadow1ng/fscan/Common"
"github.com/shadow1ng/fscan/WebScan/lib"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/shadow1ng/fscan/Common"
"github.com/shadow1ng/fscan/WebScan/lib"
)
// 常量定义
const (
protocolHTTP = "http://"
protocolHTTPS = "https://"
yamlExt = ".yaml"
ymlExt = ".yml"
defaultTimeout = 30 * time.Second
concurrencyLimit = 10 // 并发加载POC的限制
)
// 错误定义
var (
ErrInvalidURL = errors.New("无效的URL格式")
ErrEmptyTarget = errors.New("目标URL为空")
ErrPocNotFound = errors.New("未找到匹配的POC")
ErrPocLoadFailed = errors.New("POC加载失败")
)
//go:embed pocs
var Pocs embed.FS
var once sync.Once
var AllPocs []*lib.Poc
var pocsFS embed.FS
var (
once sync.Once
allPocs []*lib.Poc
)
// WebScan 执行Web漏洞扫描
func WebScan(info *Common.HostInfo) {
once.Do(initpoc)
// 初始化POC
once.Do(initPocs)
var pocinfo = Common.Pocinfo
// 自动构建URL
if info.Url == "" {
info.Url = fmt.Sprintf("http://%s:%s", info.Host, info.Ports)
// 验证输入
if info == nil {
Common.LogError("无效的扫描目标")
return
}
urlParts := strings.Split(info.Url, "/")
// 检查切片长度并构建目标URL
if len(urlParts) >= 3 {
pocinfo.Target = strings.Join(urlParts[:3], "/")
} else {
pocinfo.Target = info.Url
if len(allPocs) == 0 {
Common.LogError("POC加载失败无法执行扫描")
return
}
Common.LogDebug(fmt.Sprintf("扫描目标: %s", pocinfo.Target))
// 构建目标URL
target, err := buildTargetURL(info)
if err != nil {
Common.LogError(fmt.Sprintf("构建目标URL失败: %v", err))
return
}
// 如果是直接调用WebPoc没有指定pocName执行所有POC
if pocinfo.PocName == "" && len(info.Infostr) == 0 {
Common.LogDebug("直接调用WebPoc执行所有POC")
Execute(pocinfo)
} else {
// 根据指纹信息选择性执行POC
if len(info.Infostr) > 0 {
for _, infostr := range info.Infostr {
pocinfo.PocName = lib.CheckInfoPoc(infostr)
if pocinfo.PocName != "" {
Common.LogDebug(fmt.Sprintf("根据指纹 %s 执行对应POC", infostr))
Execute(pocinfo)
}
}
} else if pocinfo.PocName != "" {
// 指定了特定的POC
Common.LogDebug(fmt.Sprintf("执行指定POC: %s", pocinfo.PocName))
Execute(pocinfo)
}
// 使用带超时的上下文
ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
defer cancel()
// 根据扫描策略执行POC
if Common.Pocinfo.PocName == "" && len(info.Infostr) == 0 {
// 执行所有POC
executePOCs(ctx, Common.PocInfo{Target: target})
} else if len(info.Infostr) > 0 {
// 基于指纹信息执行POC
scanByFingerprints(ctx, target, info.Infostr)
} else if Common.Pocinfo.PocName != "" {
// 基于指定POC名称执行
executePOCs(ctx, Common.PocInfo{Target: target, PocName: Common.Pocinfo.PocName})
}
}
// Execute 执行具体的POC检测
func Execute(PocInfo Common.PocInfo) {
Common.LogDebug(fmt.Sprintf("开始执行POC检测目标: %s", PocInfo.Target))
// buildTargetURL 构建规范的目标URL
func buildTargetURL(info *Common.HostInfo) (string, error) {
// 自动构建URL
if info.Url == "" {
info.Url = fmt.Sprintf("%s%s:%s", protocolHTTP, info.Host, info.Ports)
} else if !hasProtocolPrefix(info.Url) {
info.Url = protocolHTTP + info.Url
}
// 解析URL以提取基础部分
parsedURL, err := url.Parse(info.Url)
if err != nil {
return "", fmt.Errorf("%w: %v", ErrInvalidURL, err)
}
return fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host), nil
}
// hasProtocolPrefix 检查URL是否包含协议前缀
func hasProtocolPrefix(urlStr string) bool {
return strings.HasPrefix(urlStr, protocolHTTP) || strings.HasPrefix(urlStr, protocolHTTPS)
}
// scanByFingerprints 根据指纹执行POC
func scanByFingerprints(ctx context.Context, target string, fingerprints []string) {
for _, fingerprint := range fingerprints {
if fingerprint == "" {
continue
}
pocName := lib.CheckInfoPoc(fingerprint)
if pocName == "" {
continue
}
executePOCs(ctx, Common.PocInfo{Target: target, PocName: pocName})
}
}
// executePOCs 执行POC检测
func executePOCs(ctx context.Context, pocInfo Common.PocInfo) {
// 验证目标
if pocInfo.Target == "" {
Common.LogError(ErrEmptyTarget.Error())
return
}
// 确保URL格式正确
if !strings.HasPrefix(PocInfo.Target, "http://") && !strings.HasPrefix(PocInfo.Target, "https://") {
PocInfo.Target = "http://" + PocInfo.Target
if !hasProtocolPrefix(pocInfo.Target) {
pocInfo.Target = protocolHTTP + pocInfo.Target
}
// 验证URL格式
_, err := url.Parse(PocInfo.Target)
// 验证URL
_, err := url.Parse(pocInfo.Target)
if err != nil {
Common.LogError(fmt.Sprintf("无效的URL格式 %v: %v", PocInfo.Target, err))
Common.LogError(fmt.Sprintf("%v %s: %v", ErrInvalidURL, pocInfo.Target, err))
return
}
// 创建基础HTTP请求
req, err := http.NewRequest("GET", PocInfo.Target, nil)
// 创建基础请求
req, err := createBaseRequest(ctx, pocInfo.Target)
if err != nil {
Common.LogError(fmt.Sprintf("初始化请求失败 %v: %v", PocInfo.Target, err))
Common.LogError(fmt.Sprintf("创建HTTP请求失败: %v", err))
return
}
// 筛选POC
matchedPocs := filterPocs(pocInfo.PocName)
if len(matchedPocs) == 0 {
Common.LogDebug(fmt.Sprintf("%v: %s", ErrPocNotFound, pocInfo.PocName))
return
}
// 执行POC检测
lib.CheckMultiPoc(req, matchedPocs, Common.PocNum)
}
// createBaseRequest 创建带上下文的HTTP请求
func createBaseRequest(ctx context.Context, target string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, "GET", target, nil)
if err != nil {
return nil, err
}
// 设置请求头
req.Header.Set("User-agent", Common.UserAgent)
req.Header.Set("Accept", Common.Accept)
@ -93,75 +174,150 @@ func Execute(PocInfo Common.PocInfo) {
req.Header.Set("Cookie", Common.Cookie)
}
// 根据名称筛选POC并执行
pocs := filterPoc(PocInfo.PocName)
Common.LogDebug(fmt.Sprintf("筛选到的POC数量: %d", len(pocs)))
lib.CheckMultiPoc(req, pocs, Common.PocNum)
return req, nil
}
// initpoc 初始化POC加载
func initpoc() {
Common.LogDebug("开始初始化POC")
// initPocs 初始化并加载POC
func initPocs() {
allPocs = make([]*lib.Poc, 0)
if Common.PocPath == "" {
Common.LogDebug("从内置目录加载POC")
// 从嵌入的POC目录加载
entries, err := Pocs.ReadDir("pocs")
loadEmbeddedPocs()
} else {
loadExternalPocs(Common.PocPath)
}
}
// loadEmbeddedPocs 加载内置POC
func loadEmbeddedPocs() {
entries, err := pocsFS.ReadDir("pocs")
if err != nil {
Common.LogError(fmt.Sprintf("加载内置POC失败: %v", err))
Common.LogError(fmt.Sprintf("加载内置POC目录失败: %v", err))
return
}
// 加载YAML格式的POC文件
// 收集所有POC文件
var pocFiles []string
for _, entry := range entries {
filename := entry.Name()
if strings.HasSuffix(filename, ".yaml") || strings.HasSuffix(filename, ".yml") {
if poc, err := lib.LoadPoc(filename, Pocs); err == nil && poc != nil {
AllPocs = append(AllPocs, poc)
} else if err != nil {
if isPocFile(entry.Name()) {
pocFiles = append(pocFiles, entry.Name())
}
}
}
Common.LogDebug(fmt.Sprintf("内置POC加载完成共加载 %d 个", len(AllPocs)))
} else {
// 从指定目录加载POC
Common.LogSuccess(fmt.Sprintf("从目录加载POC: %s", Common.PocPath))
err := filepath.Walk(Common.PocPath, func(path string, info os.FileInfo, err error) error {
if err != nil || info == nil {
return err
}
if !info.IsDir() && (strings.HasSuffix(path, ".yaml") || strings.HasSuffix(path, ".yml")) {
if poc, err := lib.LoadPocbyPath(path); err == nil && poc != nil {
AllPocs = append(AllPocs, poc)
} else if err != nil {
// 并发加载POC文件
loadPocsConcurrently(pocFiles, true, "")
}
// loadExternalPocs 从外部路径加载POC
func loadExternalPocs(pocPath string) {
if !directoryExists(pocPath) {
Common.LogError(fmt.Sprintf("POC目录不存在: %s", pocPath))
return
}
// 收集所有POC文件路径
var pocFiles []string
err := filepath.Walk(pocPath, func(path string, info os.FileInfo, err error) error {
if err != nil || info == nil || info.IsDir() {
return nil
}
if isPocFile(info.Name()) {
pocFiles = append(pocFiles, path)
}
return nil
})
if err != nil {
Common.LogError(fmt.Sprintf("加载外部POC失败: %v", err))
}
Common.LogDebug(fmt.Sprintf("外部POC加载完成共加载 %d 个", len(AllPocs)))
Common.LogError(fmt.Sprintf("遍历POC目录失败: %v", err))
return
}
// 并发加载POC文件
loadPocsConcurrently(pocFiles, false, pocPath)
}
// filterPoc 根据POC名称筛选
func filterPoc(pocname string) []*lib.Poc {
Common.LogDebug(fmt.Sprintf("开始筛选POC筛选条件: %s", pocname))
if pocname == "" {
Common.LogDebug(fmt.Sprintf("未指定POC名称返回所有POC: %d 个", len(AllPocs)))
return AllPocs
// loadPocsConcurrently 并发加载POC文件
func loadPocsConcurrently(pocFiles []string, isEmbedded bool, pocPath string) {
pocCount := len(pocFiles)
if pocCount == 0 {
return
}
var wg sync.WaitGroup
var mu sync.Mutex
var successCount, failCount int
// 使用信号量控制并发数
semaphore := make(chan struct{}, concurrencyLimit)
for _, file := range pocFiles {
wg.Add(1)
semaphore <- struct{}{} // 获取信号量
go func(filename string) {
defer func() {
<-semaphore // 释放信号量
wg.Done()
}()
var poc *lib.Poc
var err error
// 根据不同的来源加载POC
if isEmbedded {
poc, err = lib.LoadPoc(filename, pocsFS)
} else {
poc, err = lib.LoadPocbyPath(filename)
}
mu.Lock()
defer mu.Unlock()
if err != nil {
failCount++
return
}
if poc != nil {
allPocs = append(allPocs, poc)
successCount++
}
}(file)
}
wg.Wait()
Common.LogSuccess(fmt.Sprintf("POC加载完成: 总共%d个成功%d个失败%d个",
pocCount, successCount, failCount))
}
// directoryExists 检查目录是否存在
func directoryExists(path string) bool {
info, err := os.Stat(path)
return err == nil && info.IsDir()
}
// isPocFile 检查文件是否为POC文件
func isPocFile(filename string) bool {
lowerName := strings.ToLower(filename)
return strings.HasSuffix(lowerName, yamlExt) || strings.HasSuffix(lowerName, ymlExt)
}
// filterPocs 根据POC名称筛选
func filterPocs(pocName string) []*lib.Poc {
if pocName == "" {
return allPocs
}
// 转换为小写以进行不区分大小写的匹配
searchName := strings.ToLower(pocName)
var matchedPocs []*lib.Poc
for _, poc := range AllPocs {
if strings.Contains(poc.Name, pocname) {
for _, poc := range allPocs {
if poc != nil && strings.Contains(strings.ToLower(poc.Name), searchName) {
matchedPocs = append(matchedPocs, poc)
}
}
Common.LogDebug(fmt.Sprintf("POC筛选完成匹配到 %d 个", len(matchedPocs)))
return matchedPocs
}

1
go.mod
View File

@ -27,6 +27,7 @@ require (
github.com/tomatome/grdp v0.0.0-20211231062539-be8adab7eaf3
golang.org/x/crypto v0.33.0
golang.org/x/net v0.35.0
golang.org/x/sync v0.11.0
golang.org/x/sys v0.30.0
golang.org/x/text v0.22.0
google.golang.org/genproto v0.0.0-20221027153422-115e99e71e1c