diff --git a/Core/PortScan.go b/Core/PortScan.go index 1e4ac61..cc7bdb4 100644 --- a/Core/PortScan.go +++ b/Core/PortScan.go @@ -1,262 +1,151 @@ package Core import ( + "context" "fmt" "github.com/shadow1ng/fscan/Common" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" "net" - "sort" "strings" "sync" + "sync/atomic" "time" ) -// Addr 表示待扫描的地址 -type Addr struct { - ip string // IP地址 - port int // 端口号 -} - -// ScanResult 扫描结果 -type ScanResult struct { - Address string // IP地址 - Port int // 端口号 - Service *ServiceInfo // 服务信息 -} - -// PortScan 执行端口扫描 -// hostslist: 待扫描的主机列表 -// ports: 待扫描的端口范围 -// timeout: 超时时间(秒) -// 返回活跃地址列表 -func PortScan(hostslist []string, ports string, timeout int64) []string { - var results []ScanResult - var aliveAddrs []string - var mu sync.Mutex - - // 解析并验证端口列表 - probePorts := Common.ParsePort(ports) - if len(probePorts) == 0 { - Common.LogError(fmt.Sprintf("端口格式错误: %s", ports)) - return aliveAddrs +// EnhancedPortScan 高性能端口扫描函数 +func EnhancedPortScan(hosts []string, ports string, timeout int64) []string { + // 解析端口和排除端口 + portList := Common.ParsePort(ports) + if len(portList) == 0 { + Common.LogError("无效端口: " + ports) + return nil } - // 排除指定端口 - probePorts = excludeNoPorts(probePorts) + exclude := make(map[int]struct{}) + for _, p := range Common.ParsePort(Common.ExcludePorts) { + exclude[p] = struct{}{} + } // 初始化并发控制 - workers := Common.ThreadNum - addrs := make(chan Addr, 100) // 待扫描地址通道 - scanResults := make(chan ScanResult, 100) // 扫描结果通道 - var wg sync.WaitGroup - var workerWg sync.WaitGroup + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + to := time.Duration(timeout) * time.Second + sem := semaphore.NewWeighted(int64(Common.ThreadNum)) + var count int64 + var aliveMap sync.Map + g, ctx := errgroup.WithContext(ctx) - // 启动扫描工作协程 - for i := 0; i < workers; i++ { - workerWg.Add(1) - go func() { - defer workerWg.Done() - for addr := range addrs { - PortConnect(addr, scanResults, timeout, &wg) + // 并发扫描所有目标 + for _, host := range hosts { + for _, port := range portList { + if _, excluded := exclude[port]; excluded { + continue } - }() - } - // 启动结果处理协程 - var resultWg sync.WaitGroup - resultWg.Add(1) - go func() { - defer resultWg.Done() - for result := range scanResults { - mu.Lock() - results = append(results, result) - aliveAddr := fmt.Sprintf("%s:%d", result.Address, result.Port) - aliveAddrs = append(aliveAddrs, aliveAddr) - mu.Unlock() - } - }() + host, port := host, port // 捕获循环变量 + addr := fmt.Sprintf("%s:%d", host, port) - // 分发扫描任务 - for _, port := range probePorts { - for _, host := range hostslist { - wg.Add(1) - addrs <- Addr{host, port} + if err := sem.Acquire(ctx, 1); err != nil { + break + } + + g.Go(func() error { + defer sem.Release(1) + + // 连接测试 + conn, err := net.DialTimeout("tcp", addr, to) + if err != nil { + return nil + } + defer conn.Close() + + // 记录开放端口 + atomic.AddInt64(&count, 1) + aliveMap.Store(addr, struct{}{}) + Common.LogSuccess("端口开放 " + addr) + Common.SaveResult(&Common.ScanResult{ + Time: time.Now(), Type: Common.PORT, Target: host, + Status: "open", Details: map[string]interface{}{"port": port}, + }) + + // 服务识别 + if Common.EnableFingerprint { + if info, err := NewPortInfoScanner(host, port, conn, to).Identify(); err == nil { + // 构建结果详情 + details := map[string]interface{}{"port": port, "service": info.Name} + if info.Version != "" { + details["version"] = info.Version + } + + // 处理额外信息 + for k, v := range info.Extras { + if v == "" { + continue + } + switch k { + case "vendor_product": + details["product"] = v + case "os", "info": + details[k] = v + } + } + if len(info.Banner) > 0 { + details["banner"] = strings.TrimSpace(info.Banner) + } + + // 保存服务结果 + Common.SaveResult(&Common.ScanResult{ + Time: time.Now(), Type: Common.SERVICE, Target: host, + Status: "identified", Details: details, + }) + + // 记录服务信息 + var sb strings.Builder + sb.WriteString("服务识别 " + addr + " => ") + if info.Name != "unknown" { + sb.WriteString("[" + info.Name + "]") + } + if info.Version != "" { + sb.WriteString(" 版本:" + info.Version) + } + + for k, v := range info.Extras { + if v == "" { + continue + } + switch k { + case "vendor_product": + sb.WriteString(" 产品:" + v) + case "os": + sb.WriteString(" 系统:" + v) + case "info": + sb.WriteString(" 信息:" + v) + } + } + + if len(info.Banner) > 0 && len(info.Banner) < 100 { + sb.WriteString(" Banner:[" + strings.TrimSpace(info.Banner) + "]") + } + + Common.LogSuccess(sb.String()) + } + } + + return nil + }) } } - // 等待所有任务完成 - close(addrs) - workerWg.Wait() - wg.Wait() - close(scanResults) - resultWg.Wait() + _ = g.Wait() + // 收集结果 + var aliveAddrs []string + aliveMap.Range(func(key, _ interface{}) bool { + aliveAddrs = append(aliveAddrs, key.(string)) + return true + }) + + Common.LogSuccess(fmt.Sprintf("扫描完成, 发现 %d 个开放端口", count)) return aliveAddrs } - -// PortConnect 执行单个端口连接检测 -// addr: 待检测的地址 -// results: 结果通道 -// timeout: 超时时间 -// wg: 等待组 -func PortConnect(addr Addr, results chan<- ScanResult, timeout int64, wg *sync.WaitGroup) { - defer wg.Done() - - var isOpen bool - var err error - var conn net.Conn - - // 尝试建立TCP连接 - conn, err = Common.WrapperTcpWithTimeout("tcp4", - fmt.Sprintf("%s:%v", addr.ip, addr.port), - time.Duration(timeout)*time.Second) - if err == nil { - defer conn.Close() - isOpen = true - } - - if err != nil || !isOpen { - return - } - - // 记录开放端口 - address := fmt.Sprintf("%s:%d", addr.ip, addr.port) - Common.LogSuccess(fmt.Sprintf("端口开放 %s", address)) - - // 保存端口扫描结果 - portResult := &Common.ScanResult{ - Time: time.Now(), - Type: Common.PORT, - Target: addr.ip, - Status: "open", - Details: map[string]interface{}{ - "port": addr.port, - }, - } - Common.SaveResult(portResult) - - // 构造扫描结果 - result := ScanResult{ - Address: addr.ip, - Port: addr.port, - } - - // 执行服务识别 - if Common.EnableFingerprint && conn != nil { - scanner := NewPortInfoScanner(addr.ip, addr.port, conn, time.Duration(timeout)*time.Second) - if serviceInfo, err := scanner.Identify(); err == nil { - result.Service = serviceInfo - - // 构造服务识别日志 - var logMsg strings.Builder - logMsg.WriteString(fmt.Sprintf("服务识别 %s => ", address)) - - if serviceInfo.Name != "unknown" { - logMsg.WriteString(fmt.Sprintf("[%s]", serviceInfo.Name)) - } - - if serviceInfo.Version != "" { - logMsg.WriteString(fmt.Sprintf(" 版本:%s", serviceInfo.Version)) - } - - // 收集服务详细信息 - details := map[string]interface{}{ - "port": addr.port, - "service": serviceInfo.Name, - } - - // 添加版本信息 - if serviceInfo.Version != "" { - details["version"] = serviceInfo.Version - } - - // 添加产品信息 - if v, ok := serviceInfo.Extras["vendor_product"]; ok && v != "" { - details["product"] = v - logMsg.WriteString(fmt.Sprintf(" 产品:%s", v)) - } - - // 添加操作系统信息 - if v, ok := serviceInfo.Extras["os"]; ok && v != "" { - details["os"] = v - logMsg.WriteString(fmt.Sprintf(" 系统:%s", v)) - } - - // 添加额外信息 - if v, ok := serviceInfo.Extras["info"]; ok && v != "" { - details["info"] = v - logMsg.WriteString(fmt.Sprintf(" 信息:%s", v)) - } - - // 添加Banner信息 - if len(serviceInfo.Banner) > 0 && len(serviceInfo.Banner) < 100 { - details["banner"] = strings.TrimSpace(serviceInfo.Banner) - logMsg.WriteString(fmt.Sprintf(" Banner:[%s]", strings.TrimSpace(serviceInfo.Banner))) - } - - // 保存服务识别结果 - serviceResult := &Common.ScanResult{ - Time: time.Now(), - Type: Common.SERVICE, - Target: addr.ip, - Status: "identified", - Details: details, - } - Common.SaveResult(serviceResult) - - Common.LogSuccess(logMsg.String()) - } - } - - results <- result -} - -// NoPortScan 生成端口列表(不进行扫描) -// hostslist: 主机列表 -// ports: 端口范围 -// 返回地址列表 -func NoPortScan(hostslist []string, ports string) []string { - var AliveAddress []string - - // 解析并排除端口 - probePorts := excludeNoPorts(Common.ParsePort(ports)) - - // 生成地址列表 - for _, port := range probePorts { - for _, host := range hostslist { - address := fmt.Sprintf("%s:%d", host, port) - AliveAddress = append(AliveAddress, address) - } - } - - return AliveAddress -} - -// excludeNoPorts 排除指定的端口 -// ports: 原始端口列表 -// 返回过滤后的端口列表 -func excludeNoPorts(ports []int) []int { - noPorts := Common.ParsePort(Common.ExcludePorts) - if len(noPorts) == 0 { - return ports - } - - // 使用map过滤端口 - temp := make(map[int]struct{}) - for _, port := range ports { - temp[port] = struct{}{} - } - - // 移除需要排除的端口 - for _, port := range noPorts { - delete(temp, port) - } - - // 转换为有序切片 - var newPorts []int - for port := range temp { - newPorts = append(newPorts, port) - } - sort.Ints(newPorts) - - return newPorts -} diff --git a/Core/ServiceScanner.go b/Core/ServiceScanner.go index 5cc9188..cef839e 100644 --- a/Core/ServiceScanner.go +++ b/Core/ServiceScanner.go @@ -92,7 +92,7 @@ func (s *ServiceScanStrategy) discoverAlivePorts(hosts []string) []string { // 根据扫描模式选择端口扫描方式 if len(hosts) > 0 { - alivePorts = PortScan(hosts, Common.Ports, Common.Timeout) + alivePorts = EnhancedPortScan(hosts, Common.Ports, Common.Timeout) Common.LogInfo(fmt.Sprintf("存活端口数量: %d", len(alivePorts))) } diff --git a/Plugins/WebTitle.go b/Plugins/WebTitle.go index 763b284..71b1bb0 100644 --- a/Plugins/WebTitle.go +++ b/Plugins/WebTitle.go @@ -8,9 +8,9 @@ import ( "io" "net" "net/http" - "net/url" "regexp" "strings" + "sync" "time" "unicode/utf8" @@ -20,175 +20,203 @@ import ( "golang.org/x/text/encoding/simplifiedchinese" ) +// 常量定义 +const ( + maxTitleLength = 100 + defaultProtocol = "http" + httpsProtocol = "https" + httpProtocol = "http" + printerFingerPrint = "打印机" + emptyTitle = "\"\"" + noTitleText = "无标题" + + // HTTP相关常量 + httpPort = "80" + httpsPort = "443" + contentEncoding = "Content-Encoding" + gzipEncoding = "gzip" + contentLength = "Content-Length" +) + +// 错误定义 +var ( + ErrNoTitle = fmt.Errorf("无法获取标题") + ErrHTTPClientInit = fmt.Errorf("HTTP客户端未初始化") + ErrReadRespBody = fmt.Errorf("读取响应内容失败") +) + +// 响应结果 +type WebResponse struct { + Url string + StatusCode int + Title string + Length string + Headers map[string]string + RedirectUrl string + Body []byte + Error error +} + +// 协议检测结果 +type ProtocolResult struct { + Protocol string + Success bool +} + // WebTitle 获取Web标题和指纹信息 func WebTitle(info *Common.HostInfo) error { - Common.LogDebug(fmt.Sprintf("开始获取Web标题,初始信息: %+v", info)) - - // 获取网站标题信息 - err, CheckData := GOWebTitle(info) - Common.LogDebug(fmt.Sprintf("GOWebTitle执行完成 - 错误: %v, 检查数据长度: %d", err, len(CheckData))) - - info.Infostr = WebScan.InfoCheck(info.Url, &CheckData) - Common.LogDebug(fmt.Sprintf("信息检查完成,获得信息: %v", info.Infostr)) - - // 检查是否为打印机,避免意外打印 - for _, v := range info.Infostr { - if v == "打印机" { - Common.LogDebug("检测到打印机,停止扫描") - return nil - } + if info == nil { + return fmt.Errorf("主机信息为空") } - // 输出错误信息(如果有) + // 初始化Url + if err := initializeUrl(info); err != nil { + Common.LogError(fmt.Sprintf("初始化Url失败: %v", err)) + return err + } + + // 获取网站标题信息 + checkData, err := fetchWebInfo(info) if err != nil { - errlog := fmt.Sprintf("网站标题 %v %v", info.Url, err) - Common.LogError(errlog) + // 记录错误但继续处理可能获取的数据 + Common.LogError(fmt.Sprintf("获取网站信息失败: %s %v", info.Url, err)) + } + + // 分析指纹 + if len(checkData) > 0 { + info.Infostr = WebScan.InfoCheck(info.Url, &checkData) + + // 检查是否为打印机,避免意外打印 + for _, v := range info.Infostr { + if v == printerFingerPrint { + Common.LogInfo("检测到打印机,停止扫描") + return nil + } + } } return err } -// GOWebTitle 获取网站标题并处理URL,增强错误处理和协议切换 -func GOWebTitle(info *Common.HostInfo) (err error, CheckData []WebScan.CheckDatas) { - Common.LogDebug(fmt.Sprintf("开始处理URL: %s", info.Url)) - - // 如果URL未指定,根据端口生成URL +// 初始化Url:根据主机和端口生成完整Url +func initializeUrl(info *Common.HostInfo) error { if info.Url == "" { - Common.LogDebug("URL为空,根据端口生成URL") + // 根据端口推断Url switch info.Ports { - case "80": - info.Url = fmt.Sprintf("http://%s", info.Host) - case "443": - info.Url = fmt.Sprintf("https://%s", info.Host) + case httpPort: + info.Url = fmt.Sprintf("%s://%s", httpProtocol, info.Host) + case httpsPort: + info.Url = fmt.Sprintf("%s://%s", httpsProtocol, info.Host) default: host := fmt.Sprintf("%s:%s", info.Host, info.Ports) - Common.LogDebug(fmt.Sprintf("正在检测主机协议: %s", host)) - protocol := GetProtocol(host, Common.Timeout) - Common.LogDebug(fmt.Sprintf("检测到协议: %s", protocol)) + protocol, err := detectProtocol(host, Common.Timeout) + if err != nil { + return fmt.Errorf("协议检测失败: %w", err) + } info.Url = fmt.Sprintf("%s://%s:%s", protocol, info.Host, info.Ports) } - } else { - // 处理未指定协议的URL - if !strings.Contains(info.Url, "://") { - Common.LogDebug("URL未包含协议,开始检测") - host := strings.Split(info.Url, "/")[0] - protocol := GetProtocol(host, Common.Timeout) - Common.LogDebug(fmt.Sprintf("检测到协议: %s", protocol)) - info.Url = fmt.Sprintf("%s://%s", protocol, info.Url) - } - } - Common.LogDebug(fmt.Sprintf("协议检测完成后的URL: %s", info.Url)) - - // 记录原始URL协议 - originalProtocol := "http" - if strings.HasPrefix(info.Url, "https://") { - originalProtocol = "https" - } - - // 第一次获取URL - Common.LogDebug("第一次尝试访问URL") - err, result, CheckData := geturl(info, 1, CheckData) - Common.LogDebug(fmt.Sprintf("第一次访问结果 - 错误: %v, 返回信息: %s", err, result)) - - // 如果访问失败并且使用的是HTTPS,尝试降级到HTTP - if err != nil && !strings.Contains(err.Error(), "EOF") { - if originalProtocol == "https" { - Common.LogDebug("HTTPS访问失败,尝试降级到HTTP") - // 替换协议部分 - info.Url = strings.Replace(info.Url, "https://", "http://", 1) - Common.LogDebug(fmt.Sprintf("降级后的URL: %s", info.Url)) - err, result, CheckData = geturl(info, 1, CheckData) - Common.LogDebug(fmt.Sprintf("HTTP降级访问结果 - 错误: %v, 返回信息: %s", err, result)) - - // 如果仍然失败,返回错误 - if err != nil && !strings.Contains(err.Error(), "EOF") { - return - } - } else { - // 如果本来就是HTTP并且失败了,直接返回错误 - return - } - } - - // 处理URL跳转 - if strings.Contains(result, "://") { - Common.LogDebug(fmt.Sprintf("检测到重定向到: %s", result)) - info.Url = result - err, result, CheckData = geturl(info, 3, CheckData) - Common.LogDebug(fmt.Sprintf("重定向请求结果 - 错误: %v, 返回信息: %s", err, result)) + } else if !strings.Contains(info.Url, "://") { + // 处理未指定协议的Url + host := strings.Split(info.Url, "/")[0] + protocol, err := detectProtocol(host, Common.Timeout) if err != nil { - // 如果重定向跟踪失败,尝试降级协议 - if strings.HasPrefix(info.Url, "https://") { - Common.LogDebug("重定向HTTPS访问失败,尝试降级到HTTP") - info.Url = strings.Replace(info.Url, "https://", "http://", 1) - err, result, CheckData = geturl(info, 3, CheckData) - Common.LogDebug(fmt.Sprintf("重定向降级访问结果 - 错误: %v, 返回信息: %s", err, result)) - } - - if err != nil { - return - } + return fmt.Errorf("协议检测失败: %w", err) } + info.Url = fmt.Sprintf("%s://%s", protocol, info.Url) } - // 处理HTTP到HTTPS的升级提示 - if result == "https" && !strings.HasPrefix(info.Url, "https://") { - Common.LogDebug("正在升级到HTTPS") - info.Url = strings.Replace(info.Url, "http://", "https://", 1) - Common.LogDebug(fmt.Sprintf("升级后的URL: %s", info.Url)) - err, result, CheckData = geturl(info, 1, CheckData) - Common.LogDebug(fmt.Sprintf("HTTPS升级访问结果 - 错误: %v, 返回信息: %s", err, result)) - - // 如果HTTPS升级后访问失败,回退到HTTP - if err != nil && !strings.Contains(err.Error(), "EOF") { - Common.LogDebug("HTTPS升级访问失败,回退到HTTP") - info.Url = strings.Replace(info.Url, "https://", "http://", 1) - err, result, CheckData = geturl(info, 1, CheckData) - Common.LogDebug(fmt.Sprintf("回退到HTTP访问结果 - 错误: %v, 返回信息: %s", err, result)) - } - - // 处理升级后的跳转 - if strings.Contains(result, "://") { - Common.LogDebug(fmt.Sprintf("协议升级后发现重定向到: %s", result)) - info.Url = result - err, _, CheckData = geturl(info, 3, CheckData) - if err != nil { - // 如果重定向跟踪失败,再次尝试降级 - if strings.HasPrefix(info.Url, "https://") { - Common.LogDebug("升级后重定向HTTPS访问失败,尝试降级到HTTP") - info.Url = strings.Replace(info.Url, "https://", "http://", 1) - err, _, CheckData = geturl(info, 3, CheckData) - } - } - } - } - - Common.LogDebug(fmt.Sprintf("GOWebTitle执行完成 - 错误: %v", err)) - return + return nil } -func geturl(info *Common.HostInfo, flag int, CheckData []WebScan.CheckDatas) (error, string, []WebScan.CheckDatas) { - Common.LogDebug(fmt.Sprintf("geturl开始执行 - URL: %s, 标志位: %d", info.Url, flag)) +// 获取Web信息:标题、指纹等 +func fetchWebInfo(info *Common.HostInfo) ([]WebScan.CheckDatas, error) { + var checkData []WebScan.CheckDatas - // 处理目标URL - Url := info.Url - if flag == 2 { - Common.LogDebug("处理favicon.ico URL") - URL, err := url.Parse(Url) - if err == nil { - Url = fmt.Sprintf("%s://%s/favicon.ico", URL.Scheme, URL.Host) + // 记录原始Url协议 + originalUrl := info.Url + isHTTPS := strings.HasPrefix(info.Url, "https://") + + // 第一次尝试访问Url + resp, err := fetchUrlWithRetry(info, false, &checkData) + + // 处理不同的错误情况 + if err != nil { + // 如果是HTTPS并失败,尝试降级到HTTP + if isHTTPS { + info.Url = strings.Replace(info.Url, "https://", "http://", 1) + resp, err = fetchUrlWithRetry(info, false, &checkData) + + // 如果HTTP也失败,恢复原始Url并返回错误 + if err != nil { + info.Url = originalUrl + return checkData, err + } } else { - Url += "/favicon.ico" + return checkData, err } - Common.LogDebug(fmt.Sprintf("favicon URL: %s", Url)) } - // 创建HTTP请求 - Common.LogDebug("开始创建HTTP请求") - req, err := http.NewRequest("GET", Url, nil) + // 处理重定向 + if resp != nil && resp.RedirectUrl != "" { + info.Url = resp.RedirectUrl + resp, err = fetchUrlWithRetry(info, true, &checkData) + + // 如果重定向后失败,尝试降级协议 + if err != nil && strings.HasPrefix(info.Url, "https://") { + info.Url = strings.Replace(info.Url, "https://", "http://", 1) + resp, err = fetchUrlWithRetry(info, true, &checkData) + } + } + + // 处理需要升级到HTTPS的情况 + if resp != nil && resp.StatusCode == 400 && !strings.HasPrefix(info.Url, "https://") { + info.Url = strings.Replace(info.Url, "http://", "https://", 1) + resp, err = fetchUrlWithRetry(info, false, &checkData) + + // 如果HTTPS升级失败,回退到HTTP + if err != nil { + info.Url = strings.Replace(info.Url, "https://", "http://", 1) + resp, err = fetchUrlWithRetry(info, false, &checkData) + } + + // 处理升级后的重定向 + if resp != nil && resp.RedirectUrl != "" { + info.Url = resp.RedirectUrl + resp, err = fetchUrlWithRetry(info, true, &checkData) + } + } + + return checkData, err +} + +// 尝试获取Url,支持重试 +func fetchUrlWithRetry(info *Common.HostInfo, followRedirect bool, checkData *[]WebScan.CheckDatas) (*WebResponse, error) { + // 获取页面内容 + resp, err := fetchUrl(info.Url, followRedirect) if err != nil { - Common.LogDebug(fmt.Sprintf("创建HTTP请求失败: %v", err)) - return err, "", CheckData + return nil, err + } + + // 保存检查数据 + if resp.Body != nil && len(resp.Body) > 0 { + headers := fmt.Sprintf("%v", resp.Headers) + *checkData = append(*checkData, WebScan.CheckDatas{resp.Body, headers}) + } + + // 保存扫描结果 + if resp.StatusCode > 0 { + saveWebResult(info, resp) + } + + return resp, nil +} + +// 抓取Url内容 +func fetchUrl(targetUrl string, followRedirect bool) (*WebResponse, error) { + // 创建HTTP请求 + req, err := http.NewRequest("GET", targetUrl, nil) + if err != nil { + return nil, fmt.Errorf("创建HTTP请求失败: %w", err) } // 设置请求头 @@ -199,378 +227,327 @@ func geturl(info *Common.HostInfo, flag int, CheckData []WebScan.CheckDatas) (er req.Header.Set("Cookie", Common.Cookie) } req.Header.Set("Connection", "close") - Common.LogDebug("已设置请求头") // 选择HTTP客户端 var client *http.Client - if flag == 1 { - client = lib.ClientNoRedirect - Common.LogDebug("使用不跟随重定向的客户端") - } else { + if followRedirect { client = lib.Client - Common.LogDebug("使用普通客户端") + } else { + client = lib.ClientNoRedirect } - // 检查客户端是否为空 if client == nil { - Common.LogDebug("错误: HTTP客户端为空") - return fmt.Errorf("HTTP客户端未初始化"), "", CheckData + return nil, ErrHTTPClientInit } // 发送请求 - Common.LogDebug("开始发送HTTP请求") resp, err := client.Do(req) if err != nil { - Common.LogDebug(fmt.Sprintf("HTTP请求失败: %v", err)) - return err, "https", CheckData + // 特殊处理SSL/TLS相关错误 + errMsg := strings.ToLower(err.Error()) + if strings.Contains(errMsg, "tls") || strings.Contains(errMsg, "ssl") || + strings.Contains(errMsg, "handshake") || strings.Contains(errMsg, "certificate") { + return &WebResponse{Error: err}, nil + } + return nil, err } defer resp.Body.Close() - Common.LogDebug(fmt.Sprintf("收到HTTP响应,状态码: %d", resp.StatusCode)) + + // 准备响应结果 + result := &WebResponse{ + Url: req.URL.String(), + StatusCode: resp.StatusCode, + Headers: make(map[string]string), + } + + // 提取响应头 + for k, v := range resp.Header { + if len(v) > 0 { + result.Headers[k] = v[0] + } + } + + // 获取内容长度 + result.Length = resp.Header.Get(contentLength) + + // 检查重定向 + redirectUrl, err := resp.Location() + if err == nil { + result.RedirectUrl = redirectUrl.String() + } // 读取响应内容 - body, err := getRespBody(resp) + body, err := readResponseBody(resp) if err != nil { - Common.LogDebug(fmt.Sprintf("读取响应内容失败: %v", err)) - return err, "https", CheckData + return result, fmt.Errorf("读取响应内容失败: %w", err) } - Common.LogDebug(fmt.Sprintf("成功读取响应内容,长度: %d", len(body))) + result.Body = body - // 保存检查数据 - CheckData = append(CheckData, WebScan.CheckDatas{body, fmt.Sprintf("%s", resp.Header)}) - Common.LogDebug("已保存检查数据") + // 提取标题 + if !utf8.Valid(body) { + body, _ = simplifiedchinese.GBK.NewDecoder().Bytes(body) + } + result.Title = extractTitle(body) - // 处理非favicon请求 - var reurl string - if flag != 2 { - // 处理编码 - if !utf8.Valid(body) { - body, _ = simplifiedchinese.GBK.NewDecoder().Bytes(body) - } - - // 获取页面信息 - title := gettitle(body) - length := resp.Header.Get("Content-Length") - if length == "" { - length = fmt.Sprintf("%v", len(body)) - } - - // 收集服务器信息 - serverInfo := make(map[string]interface{}) - serverInfo["title"] = title - serverInfo["length"] = length - serverInfo["status_code"] = resp.StatusCode - - // 收集响应头信息 - for k, v := range resp.Header { - if len(v) > 0 { - serverInfo[strings.ToLower(k)] = v[0] - } - } - - // 检查重定向 - redirURL, err1 := resp.Location() - if err1 == nil { - reurl = redirURL.String() - serverInfo["redirect_url"] = reurl - } - - // 处理指纹信息 - 添加调试日志 - Common.LogDebug(fmt.Sprintf("保存结果前的指纹信息: %v", info.Infostr)) - - // 处理空指纹情况 - fingerprints := info.Infostr - if len(fingerprints) == 1 && fingerprints[0] == "" { - // 如果是只包含空字符串的数组,替换为空数组 - fingerprints = []string{} - Common.LogDebug("检测到空指纹,已转换为空数组") - } - - // 保存扫描结果 - result := &Common.ScanResult{ - Time: time.Now(), - Type: Common.SERVICE, - Target: info.Host, - Status: "identified", - Details: map[string]interface{}{ - "port": info.Ports, - "service": "http", - "title": title, - "url": resp.Request.URL.String(), - "status_code": resp.StatusCode, - "length": length, - "server_info": serverInfo, - "fingerprints": fingerprints, // 使用处理过的指纹信息 - }, - } - Common.SaveResult(result) - Common.LogDebug(fmt.Sprintf("已保存结果,指纹信息: %v", fingerprints)) - - // 输出控制台日志 - logMsg := fmt.Sprintf("网站标题 %-25v 状态码:%-3v 长度:%-6v 标题:%v", - resp.Request.URL, resp.StatusCode, length, title) - if reurl != "" { - logMsg += fmt.Sprintf(" 重定向地址: %s", reurl) - } - // 添加指纹信息到控制台日志 - if len(fingerprints) > 0 { - logMsg += fmt.Sprintf(" 指纹:%v", fingerprints) - } - Common.LogSuccess(logMsg) + if result.Length == "" { + result.Length = fmt.Sprintf("%d", len(body)) } - // 返回结果 - if reurl != "" { - Common.LogDebug(fmt.Sprintf("返回重定向URL: %s", reurl)) - return nil, reurl, CheckData - } - if resp.StatusCode == 400 && !strings.HasPrefix(info.Url, "https") { - Common.LogDebug("返回HTTPS升级标志") - return nil, "https", CheckData - } - Common.LogDebug("geturl执行完成,无特殊返回") - return nil, "", CheckData + return result, nil } -// getRespBody 读取HTTP响应体内容 -func getRespBody(oResp *http.Response) ([]byte, error) { - Common.LogDebug("开始读取响应体内容") +// 读取HTTP响应体内容 +func readResponseBody(resp *http.Response) ([]byte, error) { var body []byte + var reader io.Reader = resp.Body // 处理gzip压缩的响应 - if oResp.Header.Get("Content-Encoding") == "gzip" { - Common.LogDebug("检测到gzip压缩,开始解压") - gr, err := gzip.NewReader(oResp.Body) + if resp.Header.Get(contentEncoding) == gzipEncoding { + gr, err := gzip.NewReader(resp.Body) if err != nil { - Common.LogDebug(fmt.Sprintf("创建gzip解压器失败: %v", err)) - return nil, err + return nil, fmt.Errorf("创建gzip解压器失败: %w", err) } defer gr.Close() - - // 循环读取解压内容 - for { - buf := make([]byte, 1024) - n, err := gr.Read(buf) - if err != nil && err != io.EOF { - Common.LogDebug(fmt.Sprintf("读取压缩内容失败: %v", err)) - return nil, err - } - if n == 0 { - break - } - body = append(body, buf...) - } - Common.LogDebug(fmt.Sprintf("gzip解压完成,内容长度: %d", len(body))) - } else { - // 直接读取未压缩的响应 - Common.LogDebug("读取未压缩的响应内容") - raw, err := io.ReadAll(oResp.Body) - if err != nil { - Common.LogDebug(fmt.Sprintf("读取响应内容失败: %v", err)) - return nil, err - } - body = raw - Common.LogDebug(fmt.Sprintf("读取完成,内容长度: %d", len(body))) + reader = gr } + + // 读取内容 + body, err := io.ReadAll(reader) + if err != nil { + return nil, fmt.Errorf("读取响应内容失败: %w", err) + } + return body, nil } -// gettitle 从HTML内容中提取网页标题 -func gettitle(body []byte) (title string) { - Common.LogDebug("开始提取网页标题") - +// 提取网页标题 +func extractTitle(body []byte) string { // 使用正则表达式匹配title标签内容 re := regexp.MustCompile("(?ims)(.*?)") find := re.FindSubmatch(body) if len(find) > 1 { - title = string(find[1]) - Common.LogDebug(fmt.Sprintf("找到原始标题: %s", title)) + title := string(find[1]) // 清理标题内容 - title = strings.TrimSpace(title) // 去除首尾空格 - title = strings.Replace(title, "\n", "", -1) // 去除换行 - title = strings.Replace(title, "\r", "", -1) // 去除回车 - title = strings.Replace(title, " ", " ", -1) // 替换HTML空格 + title = strings.TrimSpace(title) + title = strings.Replace(title, "\n", "", -1) + title = strings.Replace(title, "\r", "", -1) + title = strings.Replace(title, " ", " ", -1) // 截断过长的标题 - if len(title) > 100 { - Common.LogDebug("标题超过100字符,进行截断") - title = title[:100] + if len(title) > maxTitleLength { + title = title[:maxTitleLength] } // 处理空标题 if title == "" { - Common.LogDebug("标题为空,使用双引号代替") - title = "\"\"" + return emptyTitle } - } else { - Common.LogDebug("未找到标题标签") - title = "无标题" + + return title } - Common.LogDebug(fmt.Sprintf("最终标题: %s", title)) - return + + return noTitleText } -// GetProtocol 检测目标主机的协议类型(HTTP/HTTPS),优先返回可用的协议 -func GetProtocol(host string, Timeout int64) (protocol string) { - Common.LogDebug(fmt.Sprintf("开始检测主机协议 - 主机: %s, 超时: %d秒", host, Timeout)) +// 保存Web扫描结果 +func saveWebResult(info *Common.HostInfo, resp *WebResponse) { + // 处理指纹信息 + fingerprints := info.Infostr + if len(fingerprints) == 1 && fingerprints[0] == "" { + fingerprints = []string{} + } - // 默认使用http协议 - protocol = "http" + // 准备服务器信息 + serverInfo := make(map[string]interface{}) + serverInfo["title"] = resp.Title + serverInfo["length"] = resp.Length + serverInfo["status_code"] = resp.StatusCode - timeoutDuration := time.Duration(Timeout) * time.Second + // 添加响应头信息 + for k, v := range resp.Headers { + serverInfo[strings.ToLower(k)] = v + } + + // 添加重定向信息 + if resp.RedirectUrl != "" { + serverInfo["redirect_Url"] = resp.RedirectUrl + } + + // 保存扫描结果 + result := &Common.ScanResult{ + Time: time.Now(), + Type: Common.SERVICE, + Target: info.Host, + Status: "identified", + Details: map[string]interface{}{ + "port": info.Ports, + "service": "http", + "title": resp.Title, + "Url": resp.Url, + "status_code": resp.StatusCode, + "length": resp.Length, + "server_info": serverInfo, + "fingerprints": fingerprints, + }, + } + Common.SaveResult(result) + + // 输出控制台日志 + logMsg := fmt.Sprintf("网站标题 %-25v 状态码:%-3v 长度:%-6v 标题:%v", + resp.Url, resp.StatusCode, resp.Length, resp.Title) + + if resp.RedirectUrl != "" { + logMsg += fmt.Sprintf(" 重定向地址: %s", resp.RedirectUrl) + } + + if len(fingerprints) > 0 { + logMsg += fmt.Sprintf(" 指纹:%v", fingerprints) + } + + Common.LogSuccess(logMsg) +} + +// 检测目标主机的协议类型(HTTP/HTTPS) +func detectProtocol(host string, timeout int64) (string, error) { + // 根据标准端口快速判断协议 + if strings.HasSuffix(host, ":"+httpPort) { + return httpProtocol, nil + } else if strings.HasSuffix(host, ":"+httpsPort) { + return httpsProtocol, nil + } + + timeoutDuration := time.Duration(timeout) * time.Second ctx, cancel := context.WithTimeout(context.Background(), timeoutDuration) defer cancel() - // 1. 根据标准端口快速判断协议 - if strings.HasSuffix(host, ":80") { - Common.LogDebug("检测到标准HTTP端口,使用HTTP协议") - return "http" - } else if strings.HasSuffix(host, ":443") { - Common.LogDebug("检测到标准HTTPS端口,使用HTTPS协议") - return "https" - } - - // 2. 并发检测HTTP和HTTPS - type protocolResult struct { - name string - success bool - } - - resultChan := make(chan protocolResult, 2) - singleTimeout := timeoutDuration / 2 // 每个协议检测的超时时间减半 + // 并发检测HTTP和HTTPS + resultChan := make(chan ProtocolResult, 2) + wg := sync.WaitGroup{} + wg.Add(2) // 检测HTTPS go func() { - Common.LogDebug("开始检测HTTPS协议") - tlsConfig := &tls.Config{ - InsecureSkipVerify: true, - MinVersion: tls.VersionTLS10, + defer wg.Done() + success := checkHTTPS(host, timeoutDuration/2) + select { + case resultChan <- ProtocolResult{httpsProtocol, success}: + case <-ctx.Done(): } - - dialer := &net.Dialer{ - Timeout: singleTimeout, - } - - conn, err := tls.DialWithDialer(dialer, "tcp", host, tlsConfig) - if err == nil { - Common.LogDebug("HTTPS连接成功") - conn.Close() - resultChan <- protocolResult{"https", true} - return - } - - // 分析TLS错误 - if err != nil { - errMsg := strings.ToLower(err.Error()) - // 这些错误可能表明服务器确实支持TLS,但有其他问题 - if strings.Contains(errMsg, "handshake failure") || - strings.Contains(errMsg, "certificate") || - strings.Contains(errMsg, "tls") || - strings.Contains(errMsg, "x509") || - strings.Contains(errMsg, "secure") { - Common.LogDebug(fmt.Sprintf("TLS握手有错误但可能是HTTPS协议: %v", err)) - resultChan <- protocolResult{"https", true} - return - } - Common.LogDebug(fmt.Sprintf("HTTPS连接失败: %v", err)) - } - resultChan <- protocolResult{"https", false} }() // 检测HTTP go func() { - Common.LogDebug("开始检测HTTP协议") - req, err := http.NewRequestWithContext(ctx, "HEAD", fmt.Sprintf("http://%s", host), nil) - if err != nil { - Common.LogDebug(fmt.Sprintf("创建HTTP请求失败: %v", err)) - resultChan <- protocolResult{"http", false} - return + defer wg.Done() + success := checkHTTP(ctx, host, timeoutDuration/2) + select { + case resultChan <- ProtocolResult{httpProtocol, success}: + case <-ctx.Done(): } - - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, - DialContext: (&net.Dialer{ - Timeout: singleTimeout, - }).DialContext, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse // 不跟随重定向 - }, - Timeout: singleTimeout, - } - - resp, err := client.Do(req) - if err == nil { - resp.Body.Close() - Common.LogDebug(fmt.Sprintf("HTTP连接成功,状态码: %d", resp.StatusCode)) - resultChan <- protocolResult{"http", true} - return - } - - Common.LogDebug(fmt.Sprintf("标准HTTP请求失败: %v,尝试原始TCP连接", err)) - - // 尝试原始TCP连接和简单HTTP请求 - netConn, err := net.DialTimeout("tcp", host, singleTimeout) - if err == nil { - defer netConn.Close() - netConn.SetDeadline(time.Now().Add(singleTimeout)) - - // 发送简单HTTP请求 - _, err = netConn.Write([]byte("HEAD / HTTP/1.0\r\nHost: " + host + "\r\n\r\n")) - if err == nil { - // 读取响应 - buf := make([]byte, 1024) - netConn.SetDeadline(time.Now().Add(singleTimeout)) - n, err := netConn.Read(buf) - if err == nil && n > 0 { - response := string(buf[:n]) - if strings.Contains(response, "HTTP/") { - Common.LogDebug("通过原始TCP连接确认HTTP协议") - resultChan <- protocolResult{"http", true} - return - } - } - } - Common.LogDebug("原始TCP连接成功但HTTP响应无效") - } else { - Common.LogDebug(fmt.Sprintf("原始TCP连接失败: %v", err)) - } - - resultChan <- protocolResult{"http", false} }() - // 3. 收集结果并决定使用哪种协议 - var httpsSuccess, httpSuccess bool + // 确保所有goroutine正常退出 + go func() { + wg.Wait() + close(resultChan) + }() - // 等待两个goroutine返回结果或超时 - for i := 0; i < 2; i++ { - select { - case result := <-resultChan: - if result.name == "https" { - httpsSuccess = result.success - Common.LogDebug(fmt.Sprintf("HTTPS检测结果: %v", httpsSuccess)) - } else if result.name == "http" { - httpSuccess = result.success - Common.LogDebug(fmt.Sprintf("HTTP检测结果: %v", httpSuccess)) - } - case <-ctx.Done(): - Common.LogDebug("协议检测超时") - break + // 收集结果 + var httpsResult, httpResult *ProtocolResult + + for result := range resultChan { + if result.Protocol == httpsProtocol { + r := result + httpsResult = &r + } else if result.Protocol == httpProtocol { + r := result + httpResult = &r } } - // 4. 决定使用哪种协议 - 优先使用HTTPS,如果HTTPS不可用则使用HTTP - if httpsSuccess { - Common.LogDebug("选择使用HTTPS协议") - return "https" - } else if httpSuccess { - Common.LogDebug("选择使用HTTP协议") - return "http" + // 决定使用哪种协议 - 优先使用HTTPS + if httpsResult != nil && httpsResult.Success { + return httpsProtocol, nil + } else if httpResult != nil && httpResult.Success { + return httpProtocol, nil } - // 5. 如果两种协议都无法确认,保持默认值 - Common.LogDebug(fmt.Sprintf("无法确定协议,使用默认协议: %s", protocol)) - return + // 默认使用HTTP + return defaultProtocol, nil +} + +// 检测HTTPS协议 +func checkHTTPS(host string, timeout time.Duration) bool { + tlsConfig := &tls.Config{ + InsecureSkipVerify: true, + MinVersion: tls.VersionTLS10, + } + + dialer := &net.Dialer{ + Timeout: timeout, + } + + conn, err := tls.DialWithDialer(dialer, "tcp", host, tlsConfig) + if err == nil { + conn.Close() + return true + } + + // 分析TLS错误,某些错误可能表明服务器支持TLS但有其他问题 + errMsg := strings.ToLower(err.Error()) + return strings.Contains(errMsg, "handshake failure") || + strings.Contains(errMsg, "certificate") || + strings.Contains(errMsg, "tls") || + strings.Contains(errMsg, "x509") || + strings.Contains(errMsg, "secure") +} + +// 检测HTTP协议 +func checkHTTP(ctx context.Context, host string, timeout time.Duration) bool { + req, err := http.NewRequestWithContext(ctx, "HEAD", fmt.Sprintf("http://%s", host), nil) + if err != nil { + return false + } + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + DialContext: (&net.Dialer{ + Timeout: timeout, + }).DialContext, + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse // 不跟随重定向 + }, + Timeout: timeout, + } + + resp, err := client.Do(req) + if err == nil { + resp.Body.Close() + return true + } + + // 尝试原始TCP连接和简单HTTP请求 + netConn, err := net.DialTimeout("tcp", host, timeout) + if err == nil { + defer netConn.Close() + netConn.SetDeadline(time.Now().Add(timeout)) + + // 发送简单HTTP请求 + _, err = netConn.Write([]byte("HEAD / HTTP/1.0\r\nHost: " + host + "\r\n\r\n")) + if err == nil { + // 读取响应 + buf := make([]byte, 1024) + netConn.SetDeadline(time.Now().Add(timeout)) + n, err := netConn.Read(buf) + if err == nil && n > 0 { + response := string(buf[:n]) + return strings.Contains(response, "HTTP/") + } + } + } + + return false } diff --git a/WebScan/WebScan.go b/WebScan/WebScan.go index 75287e0..daa4945 100644 --- a/WebScan/WebScan.go +++ b/WebScan/WebScan.go @@ -1,90 +1,171 @@ package WebScan import ( + "context" "embed" + "errors" "fmt" - "github.com/shadow1ng/fscan/Common" - "github.com/shadow1ng/fscan/WebScan/lib" "net/http" "net/url" "os" "path/filepath" "strings" "sync" + "time" + + "github.com/shadow1ng/fscan/Common" + "github.com/shadow1ng/fscan/WebScan/lib" +) + +// 常量定义 +const ( + protocolHTTP = "http://" + protocolHTTPS = "https://" + yamlExt = ".yaml" + ymlExt = ".yml" + defaultTimeout = 30 * time.Second + concurrencyLimit = 10 // 并发加载POC的限制 +) + +// 错误定义 +var ( + ErrInvalidURL = errors.New("无效的URL格式") + ErrEmptyTarget = errors.New("目标URL为空") + ErrPocNotFound = errors.New("未找到匹配的POC") + ErrPocLoadFailed = errors.New("POC加载失败") ) //go:embed pocs -var Pocs embed.FS -var once sync.Once -var AllPocs []*lib.Poc +var pocsFS embed.FS +var ( + once sync.Once + allPocs []*lib.Poc +) // WebScan 执行Web漏洞扫描 func WebScan(info *Common.HostInfo) { - once.Do(initpoc) + // 初始化POC + once.Do(initPocs) - var pocinfo = Common.Pocinfo - - // 自动构建URL - if info.Url == "" { - info.Url = fmt.Sprintf("http://%s:%s", info.Host, info.Ports) + // 验证输入 + if info == nil { + Common.LogError("无效的扫描目标") + return } - urlParts := strings.Split(info.Url, "/") - - // 检查切片长度并构建目标URL - if len(urlParts) >= 3 { - pocinfo.Target = strings.Join(urlParts[:3], "/") - } else { - pocinfo.Target = info.Url + if len(allPocs) == 0 { + Common.LogError("POC加载失败,无法执行扫描") + return } - Common.LogDebug(fmt.Sprintf("扫描目标: %s", pocinfo.Target)) + // 构建目标URL + target, err := buildTargetURL(info) + if err != nil { + Common.LogError(fmt.Sprintf("构建目标URL失败: %v", err)) + return + } - // 如果是直接调用WebPoc(没有指定pocName),执行所有POC - if pocinfo.PocName == "" && len(info.Infostr) == 0 { - Common.LogDebug("直接调用WebPoc,执行所有POC") - Execute(pocinfo) - } else { - // 根据指纹信息选择性执行POC - if len(info.Infostr) > 0 { - for _, infostr := range info.Infostr { - pocinfo.PocName = lib.CheckInfoPoc(infostr) - if pocinfo.PocName != "" { - Common.LogDebug(fmt.Sprintf("根据指纹 %s 执行对应POC", infostr)) - Execute(pocinfo) - } - } - } else if pocinfo.PocName != "" { - // 指定了特定的POC - Common.LogDebug(fmt.Sprintf("执行指定POC: %s", pocinfo.PocName)) - Execute(pocinfo) - } + // 使用带超时的上下文 + ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout) + defer cancel() + + // 根据扫描策略执行POC + if Common.Pocinfo.PocName == "" && len(info.Infostr) == 0 { + // 执行所有POC + executePOCs(ctx, Common.PocInfo{Target: target}) + } else if len(info.Infostr) > 0 { + // 基于指纹信息执行POC + scanByFingerprints(ctx, target, info.Infostr) + } else if Common.Pocinfo.PocName != "" { + // 基于指定POC名称执行 + executePOCs(ctx, Common.PocInfo{Target: target, PocName: Common.Pocinfo.PocName}) } } -// Execute 执行具体的POC检测 -func Execute(PocInfo Common.PocInfo) { - Common.LogDebug(fmt.Sprintf("开始执行POC检测,目标: %s", PocInfo.Target)) +// buildTargetURL 构建规范的目标URL +func buildTargetURL(info *Common.HostInfo) (string, error) { + // 自动构建URL + if info.Url == "" { + info.Url = fmt.Sprintf("%s%s:%s", protocolHTTP, info.Host, info.Ports) + } else if !hasProtocolPrefix(info.Url) { + info.Url = protocolHTTP + info.Url + } + + // 解析URL以提取基础部分 + parsedURL, err := url.Parse(info.Url) + if err != nil { + return "", fmt.Errorf("%w: %v", ErrInvalidURL, err) + } + + return fmt.Sprintf("%s://%s", parsedURL.Scheme, parsedURL.Host), nil +} + +// hasProtocolPrefix 检查URL是否包含协议前缀 +func hasProtocolPrefix(urlStr string) bool { + return strings.HasPrefix(urlStr, protocolHTTP) || strings.HasPrefix(urlStr, protocolHTTPS) +} + +// scanByFingerprints 根据指纹执行POC +func scanByFingerprints(ctx context.Context, target string, fingerprints []string) { + for _, fingerprint := range fingerprints { + if fingerprint == "" { + continue + } + + pocName := lib.CheckInfoPoc(fingerprint) + if pocName == "" { + continue + } + + executePOCs(ctx, Common.PocInfo{Target: target, PocName: pocName}) + } +} + +// executePOCs 执行POC检测 +func executePOCs(ctx context.Context, pocInfo Common.PocInfo) { + // 验证目标 + if pocInfo.Target == "" { + Common.LogError(ErrEmptyTarget.Error()) + return + } // 确保URL格式正确 - if !strings.HasPrefix(PocInfo.Target, "http://") && !strings.HasPrefix(PocInfo.Target, "https://") { - PocInfo.Target = "http://" + PocInfo.Target + if !hasProtocolPrefix(pocInfo.Target) { + pocInfo.Target = protocolHTTP + pocInfo.Target } - // 验证URL格式 - _, err := url.Parse(PocInfo.Target) + // 验证URL + _, err := url.Parse(pocInfo.Target) if err != nil { - Common.LogError(fmt.Sprintf("无效的URL格式 %v: %v", PocInfo.Target, err)) + Common.LogError(fmt.Sprintf("%v %s: %v", ErrInvalidURL, pocInfo.Target, err)) return } - // 创建基础HTTP请求 - req, err := http.NewRequest("GET", PocInfo.Target, nil) + // 创建基础请求 + req, err := createBaseRequest(ctx, pocInfo.Target) if err != nil { - Common.LogError(fmt.Sprintf("初始化请求失败 %v: %v", PocInfo.Target, err)) + Common.LogError(fmt.Sprintf("创建HTTP请求失败: %v", err)) return } + // 筛选POC + matchedPocs := filterPocs(pocInfo.PocName) + if len(matchedPocs) == 0 { + Common.LogDebug(fmt.Sprintf("%v: %s", ErrPocNotFound, pocInfo.PocName)) + return + } + + // 执行POC检测 + lib.CheckMultiPoc(req, matchedPocs, Common.PocNum) +} + +// createBaseRequest 创建带上下文的HTTP请求 +func createBaseRequest(ctx context.Context, target string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, "GET", target, nil) + if err != nil { + return nil, err + } + // 设置请求头 req.Header.Set("User-agent", Common.UserAgent) req.Header.Set("Accept", Common.Accept) @@ -93,75 +174,150 @@ func Execute(PocInfo Common.PocInfo) { req.Header.Set("Cookie", Common.Cookie) } - // 根据名称筛选POC并执行 - pocs := filterPoc(PocInfo.PocName) - Common.LogDebug(fmt.Sprintf("筛选到的POC数量: %d", len(pocs))) - lib.CheckMultiPoc(req, pocs, Common.PocNum) + return req, nil } -// initpoc 初始化POC加载 -func initpoc() { - Common.LogDebug("开始初始化POC") +// initPocs 初始化并加载POC +func initPocs() { + allPocs = make([]*lib.Poc, 0) if Common.PocPath == "" { - Common.LogDebug("从内置目录加载POC") - // 从嵌入的POC目录加载 - entries, err := Pocs.ReadDir("pocs") - if err != nil { - Common.LogError(fmt.Sprintf("加载内置POC失败: %v", err)) - return - } - - // 加载YAML格式的POC文件 - for _, entry := range entries { - filename := entry.Name() - if strings.HasSuffix(filename, ".yaml") || strings.HasSuffix(filename, ".yml") { - if poc, err := lib.LoadPoc(filename, Pocs); err == nil && poc != nil { - AllPocs = append(AllPocs, poc) - } else if err != nil { - } - } - } - Common.LogDebug(fmt.Sprintf("内置POC加载完成,共加载 %d 个", len(AllPocs))) + loadEmbeddedPocs() } else { - // 从指定目录加载POC - Common.LogSuccess(fmt.Sprintf("从目录加载POC: %s", Common.PocPath)) - err := filepath.Walk(Common.PocPath, func(path string, info os.FileInfo, err error) error { - if err != nil || info == nil { - return err - } - - if !info.IsDir() && (strings.HasSuffix(path, ".yaml") || strings.HasSuffix(path, ".yml")) { - if poc, err := lib.LoadPocbyPath(path); err == nil && poc != nil { - AllPocs = append(AllPocs, poc) - } else if err != nil { - } - } - return nil - }) - - if err != nil { - Common.LogError(fmt.Sprintf("加载外部POC失败: %v", err)) - } - Common.LogDebug(fmt.Sprintf("外部POC加载完成,共加载 %d 个", len(AllPocs))) + loadExternalPocs(Common.PocPath) } } -// filterPoc 根据POC名称筛选 -func filterPoc(pocname string) []*lib.Poc { - Common.LogDebug(fmt.Sprintf("开始筛选POC,筛选条件: %s", pocname)) - - if pocname == "" { - Common.LogDebug(fmt.Sprintf("未指定POC名称,返回所有POC: %d 个", len(AllPocs))) - return AllPocs +// loadEmbeddedPocs 加载内置POC +func loadEmbeddedPocs() { + entries, err := pocsFS.ReadDir("pocs") + if err != nil { + Common.LogError(fmt.Sprintf("加载内置POC目录失败: %v", err)) + return } + // 收集所有POC文件 + var pocFiles []string + for _, entry := range entries { + if isPocFile(entry.Name()) { + pocFiles = append(pocFiles, entry.Name()) + } + } + + // 并发加载POC文件 + loadPocsConcurrently(pocFiles, true, "") +} + +// loadExternalPocs 从外部路径加载POC +func loadExternalPocs(pocPath string) { + if !directoryExists(pocPath) { + Common.LogError(fmt.Sprintf("POC目录不存在: %s", pocPath)) + return + } + + // 收集所有POC文件路径 + var pocFiles []string + err := filepath.Walk(pocPath, func(path string, info os.FileInfo, err error) error { + if err != nil || info == nil || info.IsDir() { + return nil + } + + if isPocFile(info.Name()) { + pocFiles = append(pocFiles, path) + } + return nil + }) + + if err != nil { + Common.LogError(fmt.Sprintf("遍历POC目录失败: %v", err)) + return + } + + // 并发加载POC文件 + loadPocsConcurrently(pocFiles, false, pocPath) +} + +// loadPocsConcurrently 并发加载POC文件 +func loadPocsConcurrently(pocFiles []string, isEmbedded bool, pocPath string) { + pocCount := len(pocFiles) + if pocCount == 0 { + return + } + + var wg sync.WaitGroup + var mu sync.Mutex + var successCount, failCount int + + // 使用信号量控制并发数 + semaphore := make(chan struct{}, concurrencyLimit) + + for _, file := range pocFiles { + wg.Add(1) + semaphore <- struct{}{} // 获取信号量 + + go func(filename string) { + defer func() { + <-semaphore // 释放信号量 + wg.Done() + }() + + var poc *lib.Poc + var err error + + // 根据不同的来源加载POC + if isEmbedded { + poc, err = lib.LoadPoc(filename, pocsFS) + } else { + poc, err = lib.LoadPocbyPath(filename) + } + + mu.Lock() + defer mu.Unlock() + + if err != nil { + failCount++ + return + } + + if poc != nil { + allPocs = append(allPocs, poc) + successCount++ + } + }(file) + } + + wg.Wait() + Common.LogSuccess(fmt.Sprintf("POC加载完成: 总共%d个,成功%d个,失败%d个", + pocCount, successCount, failCount)) +} + +// directoryExists 检查目录是否存在 +func directoryExists(path string) bool { + info, err := os.Stat(path) + return err == nil && info.IsDir() +} + +// isPocFile 检查文件是否为POC文件 +func isPocFile(filename string) bool { + lowerName := strings.ToLower(filename) + return strings.HasSuffix(lowerName, yamlExt) || strings.HasSuffix(lowerName, ymlExt) +} + +// filterPocs 根据POC名称筛选 +func filterPocs(pocName string) []*lib.Poc { + if pocName == "" { + return allPocs + } + + // 转换为小写以进行不区分大小写的匹配 + searchName := strings.ToLower(pocName) + var matchedPocs []*lib.Poc - for _, poc := range AllPocs { - if strings.Contains(poc.Name, pocname) { + for _, poc := range allPocs { + if poc != nil && strings.Contains(strings.ToLower(poc.Name), searchName) { matchedPocs = append(matchedPocs, poc) } } - Common.LogDebug(fmt.Sprintf("POC筛选完成,匹配到 %d 个", len(matchedPocs))) + return matchedPocs } diff --git a/go.mod b/go.mod index b6898d3..a3a5320 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/tomatome/grdp v0.0.0-20211231062539-be8adab7eaf3 golang.org/x/crypto v0.33.0 golang.org/x/net v0.35.0 + golang.org/x/sync v0.11.0 golang.org/x/sys v0.30.0 golang.org/x/text v0.22.0 google.golang.org/genproto v0.0.0-20221027153422-115e99e71e1c