perf: 日常优化

This commit is contained in:
ZacharyZcR 2025-05-05 04:00:35 +08:00
parent 2b4a4024b8
commit 0dc4a6c360
5 changed files with 809 additions and 786 deletions

View File

@ -1,262 +1,151 @@
package Core package Core
import ( import (
"context"
"fmt" "fmt"
"github.com/shadow1ng/fscan/Common" "github.com/shadow1ng/fscan/Common"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
"net" "net"
"sort"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
// Addr 表示待扫描的地址 // EnhancedPortScan 高性能端口扫描函数
type Addr struct { func EnhancedPortScan(hosts []string, ports string, timeout int64) []string {
ip string // IP地址 // 解析端口和排除端口
port int // 端口号 portList := Common.ParsePort(ports)
} if len(portList) == 0 {
Common.LogError("无效端口: " + ports)
// ScanResult 扫描结果 return nil
type ScanResult struct {
Address string // IP地址
Port int // 端口号
Service *ServiceInfo // 服务信息
}
// PortScan 执行端口扫描
// hostslist: 待扫描的主机列表
// ports: 待扫描的端口范围
// timeout: 超时时间(秒)
// 返回活跃地址列表
func PortScan(hostslist []string, ports string, timeout int64) []string {
var results []ScanResult
var aliveAddrs []string
var mu sync.Mutex
// 解析并验证端口列表
probePorts := Common.ParsePort(ports)
if len(probePorts) == 0 {
Common.LogError(fmt.Sprintf("端口格式错误: %s", ports))
return aliveAddrs
} }
// 排除指定端口 exclude := make(map[int]struct{})
probePorts = excludeNoPorts(probePorts) for _, p := range Common.ParsePort(Common.ExcludePorts) {
exclude[p] = struct{}{}
}
// 初始化并发控制 // 初始化并发控制
workers := Common.ThreadNum ctx, cancel := context.WithCancel(context.Background())
addrs := make(chan Addr, 100) // 待扫描地址通道 defer cancel()
scanResults := make(chan ScanResult, 100) // 扫描结果通道 to := time.Duration(timeout) * time.Second
var wg sync.WaitGroup sem := semaphore.NewWeighted(int64(Common.ThreadNum))
var workerWg sync.WaitGroup var count int64
var aliveMap sync.Map
g, ctx := errgroup.WithContext(ctx)
// 启动扫描工作协程 // 并发扫描所有目标
for i := 0; i < workers; i++ { for _, host := range hosts {
workerWg.Add(1) for _, port := range portList {
go func() { if _, excluded := exclude[port]; excluded {
defer workerWg.Done() continue
for addr := range addrs {
PortConnect(addr, scanResults, timeout, &wg)
} }
}()
}
// 启动结果处理协程 host, port := host, port // 捕获循环变量
var resultWg sync.WaitGroup addr := fmt.Sprintf("%s:%d", host, port)
resultWg.Add(1)
go func() {
defer resultWg.Done()
for result := range scanResults {
mu.Lock()
results = append(results, result)
aliveAddr := fmt.Sprintf("%s:%d", result.Address, result.Port)
aliveAddrs = append(aliveAddrs, aliveAddr)
mu.Unlock()
}
}()
// 分发扫描任务 if err := sem.Acquire(ctx, 1); err != nil {
for _, port := range probePorts { break
for _, host := range hostslist { }
wg.Add(1)
addrs <- Addr{host, port} g.Go(func() error {
defer sem.Release(1)
// 连接测试
conn, err := net.DialTimeout("tcp", addr, to)
if err != nil {
return nil
}
defer conn.Close()
// 记录开放端口
atomic.AddInt64(&count, 1)
aliveMap.Store(addr, struct{}{})
Common.LogSuccess("端口开放 " + addr)
Common.SaveResult(&Common.ScanResult{
Time: time.Now(), Type: Common.PORT, Target: host,
Status: "open", Details: map[string]interface{}{"port": port},
})
// 服务识别
if Common.EnableFingerprint {
if info, err := NewPortInfoScanner(host, port, conn, to).Identify(); err == nil {
// 构建结果详情
details := map[string]interface{}{"port": port, "service": info.Name}
if info.Version != "" {
details["version"] = info.Version
}
// 处理额外信息
for k, v := range info.Extras {
if v == "" {
continue
}
switch k {
case "vendor_product":
details["product"] = v
case "os", "info":
details[k] = v
}
}
if len(info.Banner) > 0 {
details["banner"] = strings.TrimSpace(info.Banner)
}
// 保存服务结果
Common.SaveResult(&Common.ScanResult{
Time: time.Now(), Type: Common.SERVICE, Target: host,
Status: "identified", Details: details,
})
// 记录服务信息
var sb strings.Builder
sb.WriteString("服务识别 " + addr + " => ")
if info.Name != "unknown" {
sb.WriteString("[" + info.Name + "]")
}
if info.Version != "" {
sb.WriteString(" 版本:" + info.Version)
}
for k, v := range info.Extras {
if v == "" {
continue
}
switch k {
case "vendor_product":
sb.WriteString(" 产品:" + v)
case "os":
sb.WriteString(" 系统:" + v)
case "info":
sb.WriteString(" 信息:" + v)
}
}
if len(info.Banner) > 0 && len(info.Banner) < 100 {
sb.WriteString(" Banner:[" + strings.TrimSpace(info.Banner) + "]")
}
Common.LogSuccess(sb.String())
}
}
return nil
})
} }
} }
// 等待所有任务完成 _ = g.Wait()
close(addrs)
workerWg.Wait()
wg.Wait()
close(scanResults)
resultWg.Wait()
// 收集结果
var aliveAddrs []string
aliveMap.Range(func(key, _ interface{}) bool {
aliveAddrs = append(aliveAddrs, key.(string))
return true
})
Common.LogSuccess(fmt.Sprintf("扫描完成, 发现 %d 个开放端口", count))
return aliveAddrs return aliveAddrs
} }
// PortConnect 执行单个端口连接检测
// addr: 待检测的地址
// results: 结果通道
// timeout: 超时时间
// wg: 等待组
func PortConnect(addr Addr, results chan<- ScanResult, timeout int64, wg *sync.WaitGroup) {
defer wg.Done()
var isOpen bool
var err error
var conn net.Conn
// 尝试建立TCP连接
conn, err = Common.WrapperTcpWithTimeout("tcp4",
fmt.Sprintf("%s:%v", addr.ip, addr.port),
time.Duration(timeout)*time.Second)
if err == nil {
defer conn.Close()
isOpen = true
}
if err != nil || !isOpen {
return
}
// 记录开放端口
address := fmt.Sprintf("%s:%d", addr.ip, addr.port)
Common.LogSuccess(fmt.Sprintf("端口开放 %s", address))
// 保存端口扫描结果
portResult := &Common.ScanResult{
Time: time.Now(),
Type: Common.PORT,
Target: addr.ip,
Status: "open",
Details: map[string]interface{}{
"port": addr.port,
},
}
Common.SaveResult(portResult)
// 构造扫描结果
result := ScanResult{
Address: addr.ip,
Port: addr.port,
}
// 执行服务识别
if Common.EnableFingerprint && conn != nil {
scanner := NewPortInfoScanner(addr.ip, addr.port, conn, time.Duration(timeout)*time.Second)
if serviceInfo, err := scanner.Identify(); err == nil {
result.Service = serviceInfo
// 构造服务识别日志
var logMsg strings.Builder
logMsg.WriteString(fmt.Sprintf("服务识别 %s => ", address))
if serviceInfo.Name != "unknown" {
logMsg.WriteString(fmt.Sprintf("[%s]", serviceInfo.Name))
}
if serviceInfo.Version != "" {
logMsg.WriteString(fmt.Sprintf(" 版本:%s", serviceInfo.Version))
}
// 收集服务详细信息
details := map[string]interface{}{
"port": addr.port,
"service": serviceInfo.Name,
}
// 添加版本信息
if serviceInfo.Version != "" {
details["version"] = serviceInfo.Version
}
// 添加产品信息
if v, ok := serviceInfo.Extras["vendor_product"]; ok && v != "" {
details["product"] = v
logMsg.WriteString(fmt.Sprintf(" 产品:%s", v))
}
// 添加操作系统信息
if v, ok := serviceInfo.Extras["os"]; ok && v != "" {
details["os"] = v
logMsg.WriteString(fmt.Sprintf(" 系统:%s", v))
}
// 添加额外信息
if v, ok := serviceInfo.Extras["info"]; ok && v != "" {
details["info"] = v
logMsg.WriteString(fmt.Sprintf(" 信息:%s", v))
}
// 添加Banner信息
if len(serviceInfo.Banner) > 0 && len(serviceInfo.Banner) < 100 {
details["banner"] = strings.TrimSpace(serviceInfo.Banner)
logMsg.WriteString(fmt.Sprintf(" Banner:[%s]", strings.TrimSpace(serviceInfo.Banner)))
}
// 保存服务识别结果
serviceResult := &Common.ScanResult{
Time: time.Now(),
Type: Common.SERVICE,
Target: addr.ip,
Status: "identified",
Details: details,
}
Common.SaveResult(serviceResult)
Common.LogSuccess(logMsg.String())
}
}
results <- result
}
// NoPortScan 生成端口列表(不进行扫描)
// hostslist: 主机列表
// ports: 端口范围
// 返回地址列表
func NoPortScan(hostslist []string, ports string) []string {
var AliveAddress []string
// 解析并排除端口
probePorts := excludeNoPorts(Common.ParsePort(ports))
// 生成地址列表
for _, port := range probePorts {
for _, host := range hostslist {
address := fmt.Sprintf("%s:%d", host, port)
AliveAddress = append(AliveAddress, address)
}
}
return AliveAddress
}
// excludeNoPorts 排除指定的端口
// ports: 原始端口列表
// 返回过滤后的端口列表
func excludeNoPorts(ports []int) []int {
noPorts := Common.ParsePort(Common.ExcludePorts)
if len(noPorts) == 0 {
return ports
}
// 使用map过滤端口
temp := make(map[int]struct{})
for _, port := range ports {
temp[port] = struct{}{}
}
// 移除需要排除的端口
for _, port := range noPorts {
delete(temp, port)
}
// 转换为有序切片
var newPorts []int
for port := range temp {
newPorts = append(newPorts, port)
}
sort.Ints(newPorts)
return newPorts
}

View File

@ -92,7 +92,7 @@ func (s *ServiceScanStrategy) discoverAlivePorts(hosts []string) []string {
// 根据扫描模式选择端口扫描方式 // 根据扫描模式选择端口扫描方式
if len(hosts) > 0 { if len(hosts) > 0 {
alivePorts = PortScan(hosts, Common.Ports, Common.Timeout) alivePorts = EnhancedPortScan(hosts, Common.Ports, Common.Timeout)
Common.LogInfo(fmt.Sprintf("存活端口数量: %d", len(alivePorts))) Common.LogInfo(fmt.Sprintf("存活端口数量: %d", len(alivePorts)))
} }

View File

@ -8,9 +8,9 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"net/url"
"regexp" "regexp"
"strings" "strings"
"sync"
"time" "time"
"unicode/utf8" "unicode/utf8"
@ -20,175 +20,203 @@ import (
"golang.org/x/text/encoding/simplifiedchinese" "golang.org/x/text/encoding/simplifiedchinese"
) )
// 常量定义
const (
maxTitleLength = 100
defaultProtocol = "http"
httpsProtocol = "https"
httpProtocol = "http"
printerFingerPrint = "打印机"
emptyTitle = "\"\""
noTitleText = "无标题"
// HTTP相关常量
httpPort = "80"
httpsPort = "443"
contentEncoding = "Content-Encoding"
gzipEncoding = "gzip"
contentLength = "Content-Length"
)
// 错误定义
var (
ErrNoTitle = fmt.Errorf("无法获取标题")
ErrHTTPClientInit = fmt.Errorf("HTTP客户端未初始化")
ErrReadRespBody = fmt.Errorf("读取响应内容失败")
)
// 响应结果
type WebResponse struct {
Url string
StatusCode int
Title string
Length string
Headers map[string]string
RedirectUrl string
Body []byte
Error error
}
// 协议检测结果
type ProtocolResult struct {
Protocol string
Success bool
}
// WebTitle 获取Web标题和指纹信息 // WebTitle 获取Web标题和指纹信息
func WebTitle(info *Common.HostInfo) error { func WebTitle(info *Common.HostInfo) error {
Common.LogDebug(fmt.Sprintf("开始获取Web标题初始信息: %+v", info)) if info == nil {
return fmt.Errorf("主机信息为空")
// 获取网站标题信息
err, CheckData := GOWebTitle(info)
Common.LogDebug(fmt.Sprintf("GOWebTitle执行完成 - 错误: %v, 检查数据长度: %d", err, len(CheckData)))
info.Infostr = WebScan.InfoCheck(info.Url, &CheckData)
Common.LogDebug(fmt.Sprintf("信息检查完成,获得信息: %v", info.Infostr))
// 检查是否为打印机,避免意外打印
for _, v := range info.Infostr {
if v == "打印机" {
Common.LogDebug("检测到打印机,停止扫描")
return nil
}
} }
// 输出错误信息(如果有) // 初始化Url
if err := initializeUrl(info); err != nil {
Common.LogError(fmt.Sprintf("初始化Url失败: %v", err))
return err
}
// 获取网站标题信息
checkData, err := fetchWebInfo(info)
if err != nil { if err != nil {
errlog := fmt.Sprintf("网站标题 %v %v", info.Url, err) // 记录错误但继续处理可能获取的数据
Common.LogError(errlog) Common.LogError(fmt.Sprintf("获取网站信息失败: %s %v", info.Url, err))
}
// 分析指纹
if len(checkData) > 0 {
info.Infostr = WebScan.InfoCheck(info.Url, &checkData)
// 检查是否为打印机,避免意外打印
for _, v := range info.Infostr {
if v == printerFingerPrint {
Common.LogInfo("检测到打印机,停止扫描")
return nil
}
}
} }
return err return err
} }
// GOWebTitle 获取网站标题并处理URL增强错误处理和协议切换 // 初始化Url根据主机和端口生成完整Url
func GOWebTitle(info *Common.HostInfo) (err error, CheckData []WebScan.CheckDatas) { func initializeUrl(info *Common.HostInfo) error {
Common.LogDebug(fmt.Sprintf("开始处理URL: %s", info.Url))
// 如果URL未指定根据端口生成URL
if info.Url == "" { if info.Url == "" {
Common.LogDebug("URL为空根据端口生成URL") // 根据端口推断Url
switch info.Ports { switch info.Ports {
case "80": case httpPort:
info.Url = fmt.Sprintf("http://%s", info.Host) info.Url = fmt.Sprintf("%s://%s", httpProtocol, info.Host)
case "443": case httpsPort:
info.Url = fmt.Sprintf("https://%s", info.Host) info.Url = fmt.Sprintf("%s://%s", httpsProtocol, info.Host)
default: default:
host := fmt.Sprintf("%s:%s", info.Host, info.Ports) host := fmt.Sprintf("%s:%s", info.Host, info.Ports)
Common.LogDebug(fmt.Sprintf("正在检测主机协议: %s", host)) protocol, err := detectProtocol(host, Common.Timeout)
protocol := GetProtocol(host, Common.Timeout) if err != nil {
Common.LogDebug(fmt.Sprintf("检测到协议: %s", protocol)) return fmt.Errorf("协议检测失败: %w", err)
}
info.Url = fmt.Sprintf("%s://%s:%s", protocol, info.Host, info.Ports) info.Url = fmt.Sprintf("%s://%s:%s", protocol, info.Host, info.Ports)
} }
} else { } else if !strings.Contains(info.Url, "://") {
// 处理未指定协议的URL // 处理未指定协议的Url
if !strings.Contains(info.Url, "://") { host := strings.Split(info.Url, "/")[0]
Common.LogDebug("URL未包含协议开始检测") protocol, err := detectProtocol(host, Common.Timeout)
host := strings.Split(info.Url, "/")[0]
protocol := GetProtocol(host, Common.Timeout)
Common.LogDebug(fmt.Sprintf("检测到协议: %s", protocol))
info.Url = fmt.Sprintf("%s://%s", protocol, info.Url)
}
}
Common.LogDebug(fmt.Sprintf("协议检测完成后的URL: %s", info.Url))
// 记录原始URL协议
originalProtocol := "http"
if strings.HasPrefix(info.Url, "https://") {
originalProtocol = "https"
}
// 第一次获取URL
Common.LogDebug("第一次尝试访问URL")
err, result, CheckData := geturl(info, 1, CheckData)
Common.LogDebug(fmt.Sprintf("第一次访问结果 - 错误: %v, 返回信息: %s", err, result))
// 如果访问失败并且使用的是HTTPS尝试降级到HTTP
if err != nil && !strings.Contains(err.Error(), "EOF") {
if originalProtocol == "https" {
Common.LogDebug("HTTPS访问失败尝试降级到HTTP")
// 替换协议部分
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
Common.LogDebug(fmt.Sprintf("降级后的URL: %s", info.Url))
err, result, CheckData = geturl(info, 1, CheckData)
Common.LogDebug(fmt.Sprintf("HTTP降级访问结果 - 错误: %v, 返回信息: %s", err, result))
// 如果仍然失败,返回错误
if err != nil && !strings.Contains(err.Error(), "EOF") {
return
}
} else {
// 如果本来就是HTTP并且失败了直接返回错误
return
}
}
// 处理URL跳转
if strings.Contains(result, "://") {
Common.LogDebug(fmt.Sprintf("检测到重定向到: %s", result))
info.Url = result
err, result, CheckData = geturl(info, 3, CheckData)
Common.LogDebug(fmt.Sprintf("重定向请求结果 - 错误: %v, 返回信息: %s", err, result))
if err != nil { if err != nil {
// 如果重定向跟踪失败,尝试降级协议 return fmt.Errorf("协议检测失败: %w", err)
if strings.HasPrefix(info.Url, "https://") {
Common.LogDebug("重定向HTTPS访问失败尝试降级到HTTP")
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
err, result, CheckData = geturl(info, 3, CheckData)
Common.LogDebug(fmt.Sprintf("重定向降级访问结果 - 错误: %v, 返回信息: %s", err, result))
}
if err != nil {
return
}
} }
info.Url = fmt.Sprintf("%s://%s", protocol, info.Url)
} }
// 处理HTTP到HTTPS的升级提示 return nil
if result == "https" && !strings.HasPrefix(info.Url, "https://") {
Common.LogDebug("正在升级到HTTPS")
info.Url = strings.Replace(info.Url, "http://", "https://", 1)
Common.LogDebug(fmt.Sprintf("升级后的URL: %s", info.Url))
err, result, CheckData = geturl(info, 1, CheckData)
Common.LogDebug(fmt.Sprintf("HTTPS升级访问结果 - 错误: %v, 返回信息: %s", err, result))
// 如果HTTPS升级后访问失败回退到HTTP
if err != nil && !strings.Contains(err.Error(), "EOF") {
Common.LogDebug("HTTPS升级访问失败回退到HTTP")
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
err, result, CheckData = geturl(info, 1, CheckData)
Common.LogDebug(fmt.Sprintf("回退到HTTP访问结果 - 错误: %v, 返回信息: %s", err, result))
}
// 处理升级后的跳转
if strings.Contains(result, "://") {
Common.LogDebug(fmt.Sprintf("协议升级后发现重定向到: %s", result))
info.Url = result
err, _, CheckData = geturl(info, 3, CheckData)
if err != nil {
// 如果重定向跟踪失败,再次尝试降级
if strings.HasPrefix(info.Url, "https://") {
Common.LogDebug("升级后重定向HTTPS访问失败尝试降级到HTTP")
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
err, _, CheckData = geturl(info, 3, CheckData)
}
}
}
}
Common.LogDebug(fmt.Sprintf("GOWebTitle执行完成 - 错误: %v", err))
return
} }
func geturl(info *Common.HostInfo, flag int, CheckData []WebScan.CheckDatas) (error, string, []WebScan.CheckDatas) { // 获取Web信息标题、指纹等
Common.LogDebug(fmt.Sprintf("geturl开始执行 - URL: %s, 标志位: %d", info.Url, flag)) func fetchWebInfo(info *Common.HostInfo) ([]WebScan.CheckDatas, error) {
var checkData []WebScan.CheckDatas
// 处理目标URL // 记录原始Url协议
Url := info.Url originalUrl := info.Url
if flag == 2 { isHTTPS := strings.HasPrefix(info.Url, "https://")
Common.LogDebug("处理favicon.ico URL")
URL, err := url.Parse(Url) // 第一次尝试访问Url
if err == nil { resp, err := fetchUrlWithRetry(info, false, &checkData)
Url = fmt.Sprintf("%s://%s/favicon.ico", URL.Scheme, URL.Host)
// 处理不同的错误情况
if err != nil {
// 如果是HTTPS并失败尝试降级到HTTP
if isHTTPS {
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
resp, err = fetchUrlWithRetry(info, false, &checkData)
// 如果HTTP也失败恢复原始Url并返回错误
if err != nil {
info.Url = originalUrl
return checkData, err
}
} else { } else {
Url += "/favicon.ico" return checkData, err
} }
Common.LogDebug(fmt.Sprintf("favicon URL: %s", Url))
} }
// 创建HTTP请求 // 处理重定向
Common.LogDebug("开始创建HTTP请求") if resp != nil && resp.RedirectUrl != "" {
req, err := http.NewRequest("GET", Url, nil) info.Url = resp.RedirectUrl
resp, err = fetchUrlWithRetry(info, true, &checkData)
// 如果重定向后失败,尝试降级协议
if err != nil && strings.HasPrefix(info.Url, "https://") {
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
resp, err = fetchUrlWithRetry(info, true, &checkData)
}
}
// 处理需要升级到HTTPS的情况
if resp != nil && resp.StatusCode == 400 && !strings.HasPrefix(info.Url, "https://") {
info.Url = strings.Replace(info.Url, "http://", "https://", 1)
resp, err = fetchUrlWithRetry(info, false, &checkData)
// 如果HTTPS升级失败回退到HTTP
if err != nil {
info.Url = strings.Replace(info.Url, "https://", "http://", 1)
resp, err = fetchUrlWithRetry(info, false, &checkData)
}
// 处理升级后的重定向
if resp != nil && resp.RedirectUrl != "" {
info.Url = resp.RedirectUrl
resp, err = fetchUrlWithRetry(info, true, &checkData)
}
}
return checkData, err
}
// 尝试获取Url支持重试
func fetchUrlWithRetry(info *Common.HostInfo, followRedirect bool, checkData *[]WebScan.CheckDatas) (*WebResponse, error) {
// 获取页面内容
resp, err := fetchUrl(info.Url, followRedirect)
if err != nil { if err != nil {
Common.LogDebug(fmt.Sprintf("创建HTTP请求失败: %v", err)) return nil, err
return err, "", CheckData }
// 保存检查数据
if resp.Body != nil && len(resp.Body) > 0 {
headers := fmt.Sprintf("%v", resp.Headers)
*checkData = append(*checkData, WebScan.CheckDatas{resp.Body, headers})
}
// 保存扫描结果
if resp.StatusCode > 0 {
saveWebResult(info, resp)
}
return resp, nil
}
// 抓取Url内容
func fetchUrl(targetUrl string, followRedirect bool) (*WebResponse, error) {
// 创建HTTP请求
req, err := http.NewRequest("GET", targetUrl, nil)
if err != nil {
return nil, fmt.Errorf("创建HTTP请求失败: %w", err)
} }
// 设置请求头 // 设置请求头
@ -199,378 +227,327 @@ func geturl(info *Common.HostInfo, flag int, CheckData []WebScan.CheckDatas) (er
req.Header.Set("Cookie", Common.Cookie) req.Header.Set("Cookie", Common.Cookie)
} }
req.Header.Set("Connection", "close") req.Header.Set("Connection", "close")
Common.LogDebug("已设置请求头")
// 选择HTTP客户端 // 选择HTTP客户端
var client *http.Client var client *http.Client
if flag == 1 { if followRedirect {
client = lib.ClientNoRedirect
Common.LogDebug("使用不跟随重定向的客户端")
} else {
client = lib.Client client = lib.Client
Common.LogDebug("使用普通客户端") } else {
client = lib.ClientNoRedirect
} }
// 检查客户端是否为空
if client == nil { if client == nil {
Common.LogDebug("错误: HTTP客户端为空") return nil, ErrHTTPClientInit
return fmt.Errorf("HTTP客户端未初始化"), "", CheckData
} }
// 发送请求 // 发送请求
Common.LogDebug("开始发送HTTP请求")
resp, err := client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
Common.LogDebug(fmt.Sprintf("HTTP请求失败: %v", err)) // 特殊处理SSL/TLS相关错误
return err, "https", CheckData errMsg := strings.ToLower(err.Error())
if strings.Contains(errMsg, "tls") || strings.Contains(errMsg, "ssl") ||
strings.Contains(errMsg, "handshake") || strings.Contains(errMsg, "certificate") {
return &WebResponse{Error: err}, nil
}
return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
Common.LogDebug(fmt.Sprintf("收到HTTP响应状态码: %d", resp.StatusCode))
// 准备响应结果
result := &WebResponse{
Url: req.URL.String(),
StatusCode: resp.StatusCode,
Headers: make(map[string]string),
}
// 提取响应头
for k, v := range resp.Header {
if len(v) > 0 {
result.Headers[k] = v[0]
}
}
// 获取内容长度
result.Length = resp.Header.Get(contentLength)
// 检查重定向
redirectUrl, err := resp.Location()
if err == nil {
result.RedirectUrl = redirectUrl.String()
}
// 读取响应内容 // 读取响应内容
body, err := getRespBody(resp) body, err := readResponseBody(resp)
if err != nil { if err != nil {
Common.LogDebug(fmt.Sprintf("读取响应内容失败: %v", err)) return result, fmt.Errorf("读取响应内容失败: %w", err)
return err, "https", CheckData
} }
Common.LogDebug(fmt.Sprintf("成功读取响应内容,长度: %d", len(body))) result.Body = body
// 保存检查数据 // 提取标题
CheckData = append(CheckData, WebScan.CheckDatas{body, fmt.Sprintf("%s", resp.Header)}) if !utf8.Valid(body) {
Common.LogDebug("已保存检查数据") body, _ = simplifiedchinese.GBK.NewDecoder().Bytes(body)
}
result.Title = extractTitle(body)
// 处理非favicon请求 if result.Length == "" {
var reurl string result.Length = fmt.Sprintf("%d", len(body))
if flag != 2 {
// 处理编码
if !utf8.Valid(body) {
body, _ = simplifiedchinese.GBK.NewDecoder().Bytes(body)
}
// 获取页面信息
title := gettitle(body)
length := resp.Header.Get("Content-Length")
if length == "" {
length = fmt.Sprintf("%v", len(body))
}
// 收集服务器信息
serverInfo := make(map[string]interface{})
serverInfo["title"] = title
serverInfo["length"] = length
serverInfo["status_code"] = resp.StatusCode
// 收集响应头信息
for k, v := range resp.Header {
if len(v) > 0 {
serverInfo[strings.ToLower(k)] = v[0]
}
}
// 检查重定向
redirURL, err1 := resp.Location()
if err1 == nil {
reurl = redirURL.String()
serverInfo["redirect_url"] = reurl
}
// 处理指纹信息 - 添加调试日志
Common.LogDebug(fmt.Sprintf("保存结果前的指纹信息: %v", info.Infostr))
// 处理空指纹情况
fingerprints := info.Infostr
if len(fingerprints) == 1 && fingerprints[0] == "" {
// 如果是只包含空字符串的数组,替换为空数组
fingerprints = []string{}
Common.LogDebug("检测到空指纹,已转换为空数组")
}
// 保存扫描结果
result := &Common.ScanResult{
Time: time.Now(),
Type: Common.SERVICE,
Target: info.Host,
Status: "identified",
Details: map[string]interface{}{
"port": info.Ports,
"service": "http",
"title": title,
"url": resp.Request.URL.String(),
"status_code": resp.StatusCode,
"length": length,
"server_info": serverInfo,
"fingerprints": fingerprints, // 使用处理过的指纹信息
},
}
Common.SaveResult(result)
Common.LogDebug(fmt.Sprintf("已保存结果,指纹信息: %v", fingerprints))
// 输出控制台日志
logMsg := fmt.Sprintf("网站标题 %-25v 状态码:%-3v 长度:%-6v 标题:%v",
resp.Request.URL, resp.StatusCode, length, title)
if reurl != "" {
logMsg += fmt.Sprintf(" 重定向地址: %s", reurl)
}
// 添加指纹信息到控制台日志
if len(fingerprints) > 0 {
logMsg += fmt.Sprintf(" 指纹:%v", fingerprints)
}
Common.LogSuccess(logMsg)
} }
// 返回结果 return result, nil
if reurl != "" {
Common.LogDebug(fmt.Sprintf("返回重定向URL: %s", reurl))
return nil, reurl, CheckData
}
if resp.StatusCode == 400 && !strings.HasPrefix(info.Url, "https") {
Common.LogDebug("返回HTTPS升级标志")
return nil, "https", CheckData
}
Common.LogDebug("geturl执行完成无特殊返回")
return nil, "", CheckData
} }
// getRespBody 读取HTTP响应体内容 // 读取HTTP响应体内容
func getRespBody(oResp *http.Response) ([]byte, error) { func readResponseBody(resp *http.Response) ([]byte, error) {
Common.LogDebug("开始读取响应体内容")
var body []byte var body []byte
var reader io.Reader = resp.Body
// 处理gzip压缩的响应 // 处理gzip压缩的响应
if oResp.Header.Get("Content-Encoding") == "gzip" { if resp.Header.Get(contentEncoding) == gzipEncoding {
Common.LogDebug("检测到gzip压缩开始解压") gr, err := gzip.NewReader(resp.Body)
gr, err := gzip.NewReader(oResp.Body)
if err != nil { if err != nil {
Common.LogDebug(fmt.Sprintf("创建gzip解压器失败: %v", err)) return nil, fmt.Errorf("创建gzip解压器失败: %w", err)
return nil, err
} }
defer gr.Close() defer gr.Close()
reader = gr
// 循环读取解压内容
for {
buf := make([]byte, 1024)
n, err := gr.Read(buf)
if err != nil && err != io.EOF {
Common.LogDebug(fmt.Sprintf("读取压缩内容失败: %v", err))
return nil, err
}
if n == 0 {
break
}
body = append(body, buf...)
}
Common.LogDebug(fmt.Sprintf("gzip解压完成内容长度: %d", len(body)))
} else {
// 直接读取未压缩的响应
Common.LogDebug("读取未压缩的响应内容")
raw, err := io.ReadAll(oResp.Body)
if err != nil {
Common.LogDebug(fmt.Sprintf("读取响应内容失败: %v", err))
return nil, err
}
body = raw
Common.LogDebug(fmt.Sprintf("读取完成,内容长度: %d", len(body)))
} }
// 读取内容
body, err := io.ReadAll(reader)
if err != nil {
return nil, fmt.Errorf("读取响应内容失败: %w", err)
}
return body, nil return body, nil
} }
// gettitle 从HTML内容中提取网页标题 // 提取网页标题
func gettitle(body []byte) (title string) { func extractTitle(body []byte) string {
Common.LogDebug("开始提取网页标题")
// 使用正则表达式匹配title标签内容 // 使用正则表达式匹配title标签内容
re := regexp.MustCompile("(?ims)<title.*?>(.*?)</title>") re := regexp.MustCompile("(?ims)<title.*?>(.*?)</title>")
find := re.FindSubmatch(body) find := re.FindSubmatch(body)
if len(find) > 1 { if len(find) > 1 {
title = string(find[1]) title := string(find[1])
Common.LogDebug(fmt.Sprintf("找到原始标题: %s", title))
// 清理标题内容 // 清理标题内容
title = strings.TrimSpace(title) // 去除首尾空格 title = strings.TrimSpace(title)
title = strings.Replace(title, "\n", "", -1) // 去除换行 title = strings.Replace(title, "\n", "", -1)
title = strings.Replace(title, "\r", "", -1) // 去除回车 title = strings.Replace(title, "\r", "", -1)
title = strings.Replace(title, "&nbsp;", " ", -1) // 替换HTML空格 title = strings.Replace(title, "&nbsp;", " ", -1)
// 截断过长的标题 // 截断过长的标题
if len(title) > 100 { if len(title) > maxTitleLength {
Common.LogDebug("标题超过100字符进行截断") title = title[:maxTitleLength]
title = title[:100]
} }
// 处理空标题 // 处理空标题
if title == "" { if title == "" {
Common.LogDebug("标题为空,使用双引号代替") return emptyTitle
title = "\"\""
} }
} else {
Common.LogDebug("未找到标题标签") return title
title = "无标题"
} }
Common.LogDebug(fmt.Sprintf("最终标题: %s", title))
return return noTitleText
} }
// GetProtocol 检测目标主机的协议类型(HTTP/HTTPS),优先返回可用的协议 // 保存Web扫描结果
func GetProtocol(host string, Timeout int64) (protocol string) { func saveWebResult(info *Common.HostInfo, resp *WebResponse) {
Common.LogDebug(fmt.Sprintf("开始检测主机协议 - 主机: %s, 超时: %d秒", host, Timeout)) // 处理指纹信息
fingerprints := info.Infostr
if len(fingerprints) == 1 && fingerprints[0] == "" {
fingerprints = []string{}
}
// 默认使用http协议 // 准备服务器信息
protocol = "http" serverInfo := make(map[string]interface{})
serverInfo["title"] = resp.Title
serverInfo["length"] = resp.Length
serverInfo["status_code"] = resp.StatusCode
timeoutDuration := time.Duration(Timeout) * time.Second // 添加响应头信息
for k, v := range resp.Headers {
serverInfo[strings.ToLower(k)] = v
}
// 添加重定向信息
if resp.RedirectUrl != "" {
serverInfo["redirect_Url"] = resp.RedirectUrl
}
// 保存扫描结果
result := &Common.ScanResult{
Time: time.Now(),
Type: Common.SERVICE,
Target: info.Host,
Status: "identified",
Details: map[string]interface{}{
"port": info.Ports,
"service": "http",
"title": resp.Title,
"Url": resp.Url,
"status_code": resp.StatusCode,
"length": resp.Length,
"server_info": serverInfo,
"fingerprints": fingerprints,
},
}
Common.SaveResult(result)
// 输出控制台日志
logMsg := fmt.Sprintf("网站标题 %-25v 状态码:%-3v 长度:%-6v 标题:%v",
resp.Url, resp.StatusCode, resp.Length, resp.Title)
if resp.RedirectUrl != "" {
logMsg += fmt.Sprintf(" 重定向地址: %s", resp.RedirectUrl)
}
if len(fingerprints) > 0 {
logMsg += fmt.Sprintf(" 指纹:%v", fingerprints)
}
Common.LogSuccess(logMsg)
}
// 检测目标主机的协议类型(HTTP/HTTPS)
func detectProtocol(host string, timeout int64) (string, error) {
// 根据标准端口快速判断协议
if strings.HasSuffix(host, ":"+httpPort) {
return httpProtocol, nil
} else if strings.HasSuffix(host, ":"+httpsPort) {
return httpsProtocol, nil
}
timeoutDuration := time.Duration(timeout) * time.Second
ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration) ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration)
defer cancel() defer cancel()
// 1. 根据标准端口快速判断协议 // 并发检测HTTP和HTTPS
if strings.HasSuffix(host, ":80") { resultChan := make(chan ProtocolResult, 2)
Common.LogDebug("检测到标准HTTP端口使用HTTP协议") wg := sync.WaitGroup{}
return "http" wg.Add(2)
} else if strings.HasSuffix(host, ":443") {
Common.LogDebug("检测到标准HTTPS端口使用HTTPS协议")
return "https"
}
// 2. 并发检测HTTP和HTTPS
type protocolResult struct {
name string
success bool
}
resultChan := make(chan protocolResult, 2)
singleTimeout := timeoutDuration / 2 // 每个协议检测的超时时间减半
// 检测HTTPS // 检测HTTPS
go func() { go func() {
Common.LogDebug("开始检测HTTPS协议") defer wg.Done()
tlsConfig := &tls.Config{ success := checkHTTPS(host, timeoutDuration/2)
InsecureSkipVerify: true, select {
MinVersion: tls.VersionTLS10, case resultChan <- ProtocolResult{httpsProtocol, success}:
case <-ctx.Done():
} }
dialer := &net.Dialer{
Timeout: singleTimeout,
}
conn, err := tls.DialWithDialer(dialer, "tcp", host, tlsConfig)
if err == nil {
Common.LogDebug("HTTPS连接成功")
conn.Close()
resultChan <- protocolResult{"https", true}
return
}
// 分析TLS错误
if err != nil {
errMsg := strings.ToLower(err.Error())
// 这些错误可能表明服务器确实支持TLS但有其他问题
if strings.Contains(errMsg, "handshake failure") ||
strings.Contains(errMsg, "certificate") ||
strings.Contains(errMsg, "tls") ||
strings.Contains(errMsg, "x509") ||
strings.Contains(errMsg, "secure") {
Common.LogDebug(fmt.Sprintf("TLS握手有错误但可能是HTTPS协议: %v", err))
resultChan <- protocolResult{"https", true}
return
}
Common.LogDebug(fmt.Sprintf("HTTPS连接失败: %v", err))
}
resultChan <- protocolResult{"https", false}
}() }()
// 检测HTTP // 检测HTTP
go func() { go func() {
Common.LogDebug("开始检测HTTP协议") defer wg.Done()
req, err := http.NewRequestWithContext(ctx, "HEAD", fmt.Sprintf("http://%s", host), nil) success := checkHTTP(ctx, host, timeoutDuration/2)
if err != nil { select {
Common.LogDebug(fmt.Sprintf("创建HTTP请求失败: %v", err)) case resultChan <- ProtocolResult{httpProtocol, success}:
resultChan <- protocolResult{"http", false} case <-ctx.Done():
return
} }
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
DialContext: (&net.Dialer{
Timeout: singleTimeout,
}).DialContext,
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse // 不跟随重定向
},
Timeout: singleTimeout,
}
resp, err := client.Do(req)
if err == nil {
resp.Body.Close()
Common.LogDebug(fmt.Sprintf("HTTP连接成功状态码: %d", resp.StatusCode))
resultChan <- protocolResult{"http", true}
return
}
Common.LogDebug(fmt.Sprintf("标准HTTP请求失败: %v尝试原始TCP连接", err))
// 尝试原始TCP连接和简单HTTP请求
netConn, err := net.DialTimeout("tcp", host, singleTimeout)
if err == nil {
defer netConn.Close()
netConn.SetDeadline(time.Now().Add(singleTimeout))
// 发送简单HTTP请求
_, err = netConn.Write([]byte("HEAD / HTTP/1.0\r\nHost: " + host + "\r\n\r\n"))
if err == nil {
// 读取响应
buf := make([]byte, 1024)
netConn.SetDeadline(time.Now().Add(singleTimeout))
n, err := netConn.Read(buf)
if err == nil && n > 0 {
response := string(buf[:n])
if strings.Contains(response, "HTTP/") {
Common.LogDebug("通过原始TCP连接确认HTTP协议")
resultChan <- protocolResult{"http", true}
return
}
}
}
Common.LogDebug("原始TCP连接成功但HTTP响应无效")
} else {
Common.LogDebug(fmt.Sprintf("原始TCP连接失败: %v", err))
}
resultChan <- protocolResult{"http", false}
}() }()
// 3. 收集结果并决定使用哪种协议 // 确保所有goroutine正常退出
var httpsSuccess, httpSuccess bool go func() {
wg.Wait()
close(resultChan)
}()
// 等待两个goroutine返回结果或超时 // 收集结果
for i := 0; i < 2; i++ { var httpsResult, httpResult *ProtocolResult
select {
case result := <-resultChan: for result := range resultChan {
if result.name == "https" { if result.Protocol == httpsProtocol {
httpsSuccess = result.success r := result
Common.LogDebug(fmt.Sprintf("HTTPS检测结果: %v", httpsSuccess)) httpsResult = &r
} else if result.name == "http" { } else if result.Protocol == httpProtocol {
httpSuccess = result.success r := result
Common.LogDebug(fmt.Sprintf("HTTP检测结果: %v", httpSuccess)) httpResult = &r
}
case <-ctx.Done():
Common.LogDebug("协议检测超时")
break
} }
} }
// 4. 决定使用哪种协议 - 优先使用HTTPS如果HTTPS不可用则使用HTTP // 决定使用哪种协议 - 优先使用HTTPS
if httpsSuccess { if httpsResult != nil && httpsResult.Success {
Common.LogDebug("选择使用HTTPS协议") return httpsProtocol, nil
return "https" } else if httpResult != nil && httpResult.Success {
} else if httpSuccess { return httpProtocol, nil
Common.LogDebug("选择使用HTTP协议")
return "http"
} }
// 5. 如果两种协议都无法确认,保持默认值 // 默认使用HTTP
Common.LogDebug(fmt.Sprintf("无法确定协议,使用默认协议: %s", protocol)) return defaultProtocol, nil
return }
// 检测HTTPS协议
func checkHTTPS(host string, timeout time.Duration) bool {
tlsConfig := &tls.Config{
InsecureSkipVerify: true,
MinVersion: tls.VersionTLS10,
}
dialer := &net.Dialer{
Timeout: timeout,
}
conn, err := tls.DialWithDialer(dialer, "tcp", host, tlsConfig)
if err == nil {
conn.Close()
return true
}
// 分析TLS错误某些错误可能表明服务器支持TLS但有其他问题
errMsg := strings.ToLower(err.Error())
return strings.Contains(errMsg, "handshake failure") ||
strings.Contains(errMsg, "certificate") ||
strings.Contains(errMsg, "tls") ||
strings.Contains(errMsg, "x509") ||
strings.Contains(errMsg, "secure")
}
// 检测HTTP协议
func checkHTTP(ctx context.Context, host string, timeout time.Duration) bool {
req, err := http.NewRequestWithContext(ctx, "HEAD", fmt.Sprintf("http://%s", host), nil)
if err != nil {
return false
}
client := &http.Client{
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
DialContext: (&net.Dialer{
Timeout: timeout,
}).DialContext,
},
CheckRedirect: func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse // 不跟随重定向
},
Timeout: timeout,
}
resp, err := client.Do(req)
if err == nil {
resp.Body.Close()
return true
}
// 尝试原始TCP连接和简单HTTP请求
netConn, err := net.DialTimeout("tcp", host, timeout)
if err == nil {
defer netConn.Close()
netConn.SetDeadline(time.Now().Add(timeout))
// 发送简单HTTP请求
_, err = netConn.Write([]byte("HEAD / HTTP/1.0\r\nHost: " + host + "\r\n\r\n"))
if err == nil {
// 读取响应
buf := make([]byte, 1024)
netConn.SetDeadline(time.Now().Add(timeout))
n, err := netConn.Read(buf)
if err == nil && n > 0 {
response := string(buf[:n])
return strings.Contains(response, "HTTP/")
}
}
}
return false
} }

View File

@ -1,90 +1,171 @@
package WebScan package WebScan
import ( import (
"context"
"embed" "embed"
"errors"
"fmt" "fmt"
"github.com/shadow1ng/fscan/Common"
"github.com/shadow1ng/fscan/WebScan/lib"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"sync" "sync"
"time"
"github.com/shadow1ng/fscan/Common"
"github.com/shadow1ng/fscan/WebScan/lib"
)
// 常量定义
const (
protocolHTTP = "http://"
protocolHTTPS = "https://"
yamlExt = ".yaml"
ymlExt = ".yml"
defaultTimeout = 30 * time.Second
concurrencyLimit = 10 // 并发加载POC的限制
)
// 错误定义
var (
ErrInvalidURL = errors.New("无效的URL格式")
ErrEmptyTarget = errors.New("目标URL为空")
ErrPocNotFound = errors.New("未找到匹配的POC")
ErrPocLoadFailed = errors.New("POC加载失败")
) )
//go:embed pocs //go:embed pocs
var Pocs embed.FS var pocsFS embed.FS
var once sync.Once var (
var AllPocs []*lib.Poc once sync.Once
allPocs []*lib.Poc
)
// WebScan 执行Web漏洞扫描 // WebScan 执行Web漏洞扫描
func WebScan(info *Common.HostInfo) { func WebScan(info *Common.HostInfo) {
once.Do(initpoc) // 初始化POC
once.Do(initPocs)
var pocinfo = Common.Pocinfo // 验证输入
if info == nil {
// 自动构建URL Common.LogError("无效的扫描目标")
if info.Url == "" { return
info.Url = fmt.Sprintf("http://%s:%s", info.Host, info.Ports)
} }
urlParts := strings.Split(info.Url, "/") if len(allPocs) == 0 {
Common.LogError("POC加载失败无法执行扫描")
// 检查切片长度并构建目标URL return
if len(urlParts) >= 3 {
pocinfo.Target = strings.Join(urlParts[:3], "/")
} else {
pocinfo.Target = info.Url
} }
Common.LogDebug(fmt.Sprintf("扫描目标: %s", pocinfo.Target)) // 构建目标URL
target, err := buildTargetURL(info)
if err != nil {
Common.LogError(fmt.Sprintf("构建目标URL失败: %v", err))
return
}
// 如果是直接调用WebPoc没有指定pocName执行所有POC // 使用带超时的上下文
if pocinfo.PocName == "" && len(info.Infostr) == 0 { ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
Common.LogDebug("直接调用WebPoc执行所有POC") defer cancel()
Execute(pocinfo)
} else { // 根据扫描策略执行POC
// 根据指纹信息选择性执行POC if Common.Pocinfo.PocName == "" && len(info.Infostr) == 0 {
if len(info.Infostr) > 0 { // 执行所有POC
for _, infostr := range info.Infostr { executePOCs(ctx, Common.PocInfo{Target: target})
pocinfo.PocName = lib.CheckInfoPoc(infostr) } else if len(info.Infostr) > 0 {
if pocinfo.PocName != "" { // 基于指纹信息执行POC
Common.LogDebug(fmt.Sprintf("根据指纹 %s 执行对应POC", infostr)) scanByFingerprints(ctx, target, info.Infostr)
Execute(pocinfo) } else if Common.Pocinfo.PocName != "" {
} // 基于指定POC名称执行
} executePOCs(ctx, Common.PocInfo{Target: target, PocName: Common.Pocinfo.PocName})
} else if pocinfo.PocName != "" {
// 指定了特定的POC
Common.LogDebug(fmt.Sprintf("执行指定POC: %s", pocinfo.PocName))
Execute(pocinfo)
}
} }
} }
// Execute 执行具体的POC检测 // buildTargetURL 构建规范的目标URL
func Execute(PocInfo Common.PocInfo) { func buildTargetURL(info *Common.HostInfo) (string, error) {
Common.LogDebug(fmt.Sprintf("开始执行POC检测目标: %s", PocInfo.Target)) // 自动构建URL
if info.Url == "" {
info.Url = fmt.Sprintf("%s%s:%s", protocolHTTP, info.Host, info.Ports)
} else if !hasProtocolPrefix(info.Url) {
info.Url = protocolHTTP + info.Url
}
// 解析URL以提取基础部分
parsedURL, err := url.Parse(info.Url)
if err != nil {
return "", fmt.Errorf("%w: %v", ErrInvalidURL, err)
}
return fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host), nil
}
// hasProtocolPrefix 检查URL是否包含协议前缀
func hasProtocolPrefix(urlStr string) bool {
return strings.HasPrefix(urlStr, protocolHTTP) || strings.HasPrefix(urlStr, protocolHTTPS)
}
// scanByFingerprints 根据指纹执行POC
func scanByFingerprints(ctx context.Context, target string, fingerprints []string) {
for _, fingerprint := range fingerprints {
if fingerprint == "" {
continue
}
pocName := lib.CheckInfoPoc(fingerprint)
if pocName == "" {
continue
}
executePOCs(ctx, Common.PocInfo{Target: target, PocName: pocName})
}
}
// executePOCs 执行POC检测
func executePOCs(ctx context.Context, pocInfo Common.PocInfo) {
// 验证目标
if pocInfo.Target == "" {
Common.LogError(ErrEmptyTarget.Error())
return
}
// 确保URL格式正确 // 确保URL格式正确
if !strings.HasPrefix(PocInfo.Target, "http://") && !strings.HasPrefix(PocInfo.Target, "https://") { if !hasProtocolPrefix(pocInfo.Target) {
PocInfo.Target = "http://" + PocInfo.Target pocInfo.Target = protocolHTTP + pocInfo.Target
} }
// 验证URL格式 // 验证URL
_, err := url.Parse(PocInfo.Target) _, err := url.Parse(pocInfo.Target)
if err != nil { if err != nil {
Common.LogError(fmt.Sprintf("无效的URL格式 %v: %v", PocInfo.Target, err)) Common.LogError(fmt.Sprintf("%v %s: %v", ErrInvalidURL, pocInfo.Target, err))
return return
} }
// 创建基础HTTP请求 // 创建基础请求
req, err := http.NewRequest("GET", PocInfo.Target, nil) req, err := createBaseRequest(ctx, pocInfo.Target)
if err != nil { if err != nil {
Common.LogError(fmt.Sprintf("初始化请求失败 %v: %v", PocInfo.Target, err)) Common.LogError(fmt.Sprintf("创建HTTP请求失败: %v", err))
return return
} }
// 筛选POC
matchedPocs := filterPocs(pocInfo.PocName)
if len(matchedPocs) == 0 {
Common.LogDebug(fmt.Sprintf("%v: %s", ErrPocNotFound, pocInfo.PocName))
return
}
// 执行POC检测
lib.CheckMultiPoc(req, matchedPocs, Common.PocNum)
}
// createBaseRequest 创建带上下文的HTTP请求
func createBaseRequest(ctx context.Context, target string) (*http.Request, error) {
req, err := http.NewRequestWithContext(ctx, "GET", target, nil)
if err != nil {
return nil, err
}
// 设置请求头 // 设置请求头
req.Header.Set("User-agent", Common.UserAgent) req.Header.Set("User-agent", Common.UserAgent)
req.Header.Set("Accept", Common.Accept) req.Header.Set("Accept", Common.Accept)
@ -93,75 +174,150 @@ func Execute(PocInfo Common.PocInfo) {
req.Header.Set("Cookie", Common.Cookie) req.Header.Set("Cookie", Common.Cookie)
} }
// 根据名称筛选POC并执行 return req, nil
pocs := filterPoc(PocInfo.PocName)
Common.LogDebug(fmt.Sprintf("筛选到的POC数量: %d", len(pocs)))
lib.CheckMultiPoc(req, pocs, Common.PocNum)
} }
// initpoc 初始化POC加载 // initPocs 初始化并加载POC
func initpoc() { func initPocs() {
Common.LogDebug("开始初始化POC") allPocs = make([]*lib.Poc, 0)
if Common.PocPath == "" { if Common.PocPath == "" {
Common.LogDebug("从内置目录加载POC") loadEmbeddedPocs()
// 从嵌入的POC目录加载
entries, err := Pocs.ReadDir("pocs")
if err != nil {
Common.LogError(fmt.Sprintf("加载内置POC失败: %v", err))
return
}
// 加载YAML格式的POC文件
for _, entry := range entries {
filename := entry.Name()
if strings.HasSuffix(filename, ".yaml") || strings.HasSuffix(filename, ".yml") {
if poc, err := lib.LoadPoc(filename, Pocs); err == nil && poc != nil {
AllPocs = append(AllPocs, poc)
} else if err != nil {
}
}
}
Common.LogDebug(fmt.Sprintf("内置POC加载完成共加载 %d 个", len(AllPocs)))
} else { } else {
// 从指定目录加载POC loadExternalPocs(Common.PocPath)
Common.LogSuccess(fmt.Sprintf("从目录加载POC: %s", Common.PocPath))
err := filepath.Walk(Common.PocPath, func(path string, info os.FileInfo, err error) error {
if err != nil || info == nil {
return err
}
if !info.IsDir() && (strings.HasSuffix(path, ".yaml") || strings.HasSuffix(path, ".yml")) {
if poc, err := lib.LoadPocbyPath(path); err == nil && poc != nil {
AllPocs = append(AllPocs, poc)
} else if err != nil {
}
}
return nil
})
if err != nil {
Common.LogError(fmt.Sprintf("加载外部POC失败: %v", err))
}
Common.LogDebug(fmt.Sprintf("外部POC加载完成共加载 %d 个", len(AllPocs)))
} }
} }
// filterPoc 根据POC名称筛选 // loadEmbeddedPocs 加载内置POC
func filterPoc(pocname string) []*lib.Poc { func loadEmbeddedPocs() {
Common.LogDebug(fmt.Sprintf("开始筛选POC筛选条件: %s", pocname)) entries, err := pocsFS.ReadDir("pocs")
if err != nil {
if pocname == "" { Common.LogError(fmt.Sprintf("加载内置POC目录失败: %v", err))
Common.LogDebug(fmt.Sprintf("未指定POC名称返回所有POC: %d 个", len(AllPocs))) return
return AllPocs
} }
// 收集所有POC文件
var pocFiles []string
for _, entry := range entries {
if isPocFile(entry.Name()) {
pocFiles = append(pocFiles, entry.Name())
}
}
// 并发加载POC文件
loadPocsConcurrently(pocFiles, true, "")
}
// loadExternalPocs 从外部路径加载POC
func loadExternalPocs(pocPath string) {
if !directoryExists(pocPath) {
Common.LogError(fmt.Sprintf("POC目录不存在: %s", pocPath))
return
}
// 收集所有POC文件路径
var pocFiles []string
err := filepath.Walk(pocPath, func(path string, info os.FileInfo, err error) error {
if err != nil || info == nil || info.IsDir() {
return nil
}
if isPocFile(info.Name()) {
pocFiles = append(pocFiles, path)
}
return nil
})
if err != nil {
Common.LogError(fmt.Sprintf("遍历POC目录失败: %v", err))
return
}
// 并发加载POC文件
loadPocsConcurrently(pocFiles, false, pocPath)
}
// loadPocsConcurrently 并发加载POC文件
func loadPocsConcurrently(pocFiles []string, isEmbedded bool, pocPath string) {
pocCount := len(pocFiles)
if pocCount == 0 {
return
}
var wg sync.WaitGroup
var mu sync.Mutex
var successCount, failCount int
// 使用信号量控制并发数
semaphore := make(chan struct{}, concurrencyLimit)
for _, file := range pocFiles {
wg.Add(1)
semaphore <- struct{}{} // 获取信号量
go func(filename string) {
defer func() {
<-semaphore // 释放信号量
wg.Done()
}()
var poc *lib.Poc
var err error
// 根据不同的来源加载POC
if isEmbedded {
poc, err = lib.LoadPoc(filename, pocsFS)
} else {
poc, err = lib.LoadPocbyPath(filename)
}
mu.Lock()
defer mu.Unlock()
if err != nil {
failCount++
return
}
if poc != nil {
allPocs = append(allPocs, poc)
successCount++
}
}(file)
}
wg.Wait()
Common.LogSuccess(fmt.Sprintf("POC加载完成: 总共%d个成功%d个失败%d个",
pocCount, successCount, failCount))
}
// directoryExists 检查目录是否存在
func directoryExists(path string) bool {
info, err := os.Stat(path)
return err == nil && info.IsDir()
}
// isPocFile 检查文件是否为POC文件
func isPocFile(filename string) bool {
lowerName := strings.ToLower(filename)
return strings.HasSuffix(lowerName, yamlExt) || strings.HasSuffix(lowerName, ymlExt)
}
// filterPocs 根据POC名称筛选
func filterPocs(pocName string) []*lib.Poc {
if pocName == "" {
return allPocs
}
// 转换为小写以进行不区分大小写的匹配
searchName := strings.ToLower(pocName)
var matchedPocs []*lib.Poc var matchedPocs []*lib.Poc
for _, poc := range AllPocs { for _, poc := range allPocs {
if strings.Contains(poc.Name, pocname) { if poc != nil && strings.Contains(strings.ToLower(poc.Name), searchName) {
matchedPocs = append(matchedPocs, poc) matchedPocs = append(matchedPocs, poc)
} }
} }
Common.LogDebug(fmt.Sprintf("POC筛选完成匹配到 %d 个", len(matchedPocs)))
return matchedPocs return matchedPocs
} }

1
go.mod
View File

@ -27,6 +27,7 @@ require (
github.com/tomatome/grdp v0.0.0-20211231062539-be8adab7eaf3 github.com/tomatome/grdp v0.0.0-20211231062539-be8adab7eaf3
golang.org/x/crypto v0.33.0 golang.org/x/crypto v0.33.0
golang.org/x/net v0.35.0 golang.org/x/net v0.35.0
golang.org/x/sync v0.11.0
golang.org/x/sys v0.30.0 golang.org/x/sys v0.30.0
golang.org/x/text v0.22.0 golang.org/x/text v0.22.0
google.golang.org/genproto v0.0.0-20221027153422-115e99e71e1c google.golang.org/genproto v0.0.0-20221027153422-115e99e71e1c