From 102d100c252ba4ce42a3fa2710d46b33d23071c5 Mon Sep 17 00:00:00 2001 From: ZacharyZcR <2903735704@qq.com> Date: Fri, 7 Feb 2025 11:39:04 +0800 Subject: [PATCH] =?UTF-8?q?perf:=20=E4=BC=98=E5=8C=96=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E7=BB=93=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Core/ICMP.go | 10 ++ Core/Scanner.go | 439 ++++++++++++++++++++++++++---------------------- 2 files changed, 245 insertions(+), 204 deletions(-) diff --git a/Core/ICMP.go b/Core/ICMP.go index 3cfe3bd..2535e12 100644 --- a/Core/ICMP.go +++ b/Core/ICMP.go @@ -45,6 +45,16 @@ func CheckLive(hostslist []string, Ping bool) []string { return AliveHosts } +// IsContain 检查切片中是否包含指定元素 +func IsContain(items []string, item string) bool { + for _, eachItem := range items { + if eachItem == item { + return true + } + } + return false +} + func handleAliveHosts(chanHosts chan string, hostslist []string, isPing bool) { for ip := range chanHosts { if _, ok := ExistHosts[ip]; !ok && IsContain(hostslist, ip) { diff --git a/Core/Scanner.go b/Core/Scanner.go index 8e1b7c1..17d15ca 100644 --- a/Core/Scanner.go +++ b/Core/Scanner.go @@ -29,103 +29,121 @@ func Scan(info Common.HostInfo) { ch := make(chan struct{}, Common.ThreadNum) wg := sync.WaitGroup{} - // 根据不同模式执行扫描 + // 执行扫描逻辑 switch { case Common.LocalMode: - // 本地信息收集模式 - LocalScan = true - - // 定义本地模式允许的插件 - validLocalPlugins := make(map[string]bool) - for _, plugin := range Common.PluginGroups[Common.ModeLocal] { - validLocalPlugins[plugin] = true - } - - // 如果没有指定扫描模式或为默认的All,设置为 ModeLocal - if Common.ScanMode == "" || Common.ScanMode == "All" { - Common.ScanMode = Common.ModeLocal - } else if Common.ScanMode != Common.ModeLocal { - // 不是完整模式时,检查是否是合法的单插件 - if !validLocalPlugins[Common.ScanMode] { - Common.LogError(fmt.Sprintf("无效的本地模式插件: %s, 仅支持 localinfo", Common.ScanMode)) - return - } - } - - if Common.ScanMode == Common.ModeLocal { - Common.LogInfo("执行本地信息收集 - 使用全部本地插件") - } else { - Common.LogInfo(fmt.Sprintf("执行本地信息收集 - 使用插件: %s", Common.ScanMode)) - } - - executeScans([]Common.HostInfo{info}, &ch, &wg) - + executeLocalScan(info, &ch, &wg) case len(Common.URLs) > 0: - // Web模式 - WebScan = true - - // 从 pluginGroups 获取Web模式允许的插件 - validWebPlugins := make(map[string]bool) - for _, plugin := range Common.PluginGroups[Common.ModeWeb] { - validWebPlugins[plugin] = true - } - - // 如果没有指定扫描模式,默认设置为 ModeWeb - if Common.ScanMode == "" || Common.ScanMode == "All" { - Common.ScanMode = Common.ModeWeb - } - - // 如果不是 ModeWeb,检查是否是合法的单插件 - if Common.ScanMode != Common.ModeWeb { - if !validWebPlugins[Common.ScanMode] { - Common.LogError(fmt.Sprintf("无效的Web插件: %s, 仅支持 webtitle 和 webpoc", Common.ScanMode)) - return - } - // ScanMode 保持为单插件名 - } - - var targetInfos []Common.HostInfo - for _, url := range Common.URLs { - urlInfo := info - if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { - url = "http://" + url - } - urlInfo.Url = url - targetInfos = append(targetInfos, urlInfo) - } - - if Common.ScanMode == Common.ModeWeb { - Common.LogInfo("开始Web扫描 - 使用全部Web插件") - } else { - Common.LogInfo(fmt.Sprintf("开始Web扫描 - 使用插件: %s", Common.ScanMode)) - } - executeScans(targetInfos, &ch, &wg) - + executeWebScan(info, &ch, &wg) default: - // 主机扫描模式 - if info.Host == "" { - Common.LogError("未指定扫描目标") - return - } - - hosts, err := Common.ParseIP(info.Host, Common.HostsFile, Common.ExcludeHosts) - if err != nil { - Common.LogError(fmt.Sprintf("解析主机错误: %v", err)) - return - } - Common.LogInfo("开始主机扫描") - executeScan(hosts, info, &ch, &wg) + executeHostScan(info, &ch, &wg) } + // 等待扫描完成 finishScan(&wg) } +// 执行本地扫描 +func executeLocalScan(info Common.HostInfo, ch *chan struct{}, wg *sync.WaitGroup) { + Common.LogInfo("执行本地信息收集") + + // 定义本地模式允许的插件 + validLocalPlugins := getValidPlugins(Common.ModeLocal) + + // 校验扫描模式 + if err := validateScanMode(validLocalPlugins, Common.ModeLocal); err != nil { + Common.LogError(err.Error()) + return + } + + if Common.ScanMode == Common.ModeLocal { + Common.LogInfo("使用全部本地插件") + } else { + Common.LogInfo(fmt.Sprintf("使用插件: %s", Common.ScanMode)) + } + + // 执行扫描 + executeScans([]Common.HostInfo{info}, ch, wg) +} + +// 执行Web扫描 +func executeWebScan(info Common.HostInfo, ch *chan struct{}, wg *sync.WaitGroup) { + Common.LogInfo("开始Web扫描") + + // 从 pluginGroups 获取Web模式允许的插件 + validWebPlugins := getValidPlugins(Common.ModeWeb) + + // 校验扫描模式 + if err := validateScanMode(validWebPlugins, Common.ModeWeb); err != nil { + Common.LogError(err.Error()) + return + } + + // 创建目标URL信息 + var targetInfos []Common.HostInfo + for _, url := range Common.URLs { + urlInfo := info + if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") { + url = "http://" + url + } + urlInfo.Url = url + targetInfos = append(targetInfos, urlInfo) + } + + if Common.ScanMode == Common.ModeWeb { + Common.LogInfo("使用全部Web插件") + } else { + Common.LogInfo(fmt.Sprintf("使用插件: %s", Common.ScanMode)) + } + + // 执行扫描 + executeScans(targetInfos, ch, wg) +} + +// 执行主机扫描 +func executeHostScan(info Common.HostInfo, ch *chan struct{}, wg *sync.WaitGroup) { + if info.Host == "" { + Common.LogError("未指定扫描目标") + return + } + + hosts, err := Common.ParseIP(info.Host, Common.HostsFile, Common.ExcludeHosts) + if err != nil { + Common.LogError(fmt.Sprintf("解析主机错误: %v", err)) + return + } + + Common.LogInfo("开始主机扫描") + executeScan(hosts, info, ch, wg) +} + +// 获取合法的插件列表 +func getValidPlugins(mode string) map[string]bool { + validPlugins := make(map[string]bool) + for _, plugin := range Common.PluginGroups[mode] { + validPlugins[plugin] = true + } + return validPlugins +} + +// 校验扫描模式是否有效 +func validateScanMode(validPlugins map[string]bool, mode string) error { + if Common.ScanMode == "" || Common.ScanMode == "All" { + Common.ScanMode = mode + } else if _, exists := validPlugins[Common.ScanMode]; !exists { + return fmt.Errorf("无效的%s插件: %s", mode, Common.ScanMode) + } + return nil +} + // executeScan 执行主扫描流程 func executeScan(hosts []string, info Common.HostInfo, ch *chan struct{}, wg *sync.WaitGroup) { var targetInfos []Common.HostInfo + // 扫描主机和端口 if len(hosts) > 0 || len(Common.HostPort) > 0 { - if (Common.DisablePing == false && len(hosts) > 1) || Common.IsICMPScan() { + // 处理活跃主机 + if shouldPingScan(hosts) { hosts = CheckLive(hosts, Common.UsePing) Common.LogInfo(fmt.Sprintf("存活主机数量: %d", len(hosts))) if Common.IsICMPScan() { @@ -133,39 +151,62 @@ func executeScan(hosts []string, info Common.HostInfo, ch *chan struct{}, wg *sy } } - var alivePorts []string - if Common.IsWebScan() { - alivePorts = NoPortScan(hosts, Common.Ports) - } else if len(hosts) > 0 { - alivePorts = PortScan(hosts, Common.Ports, Common.Timeout) - Common.LogInfo(fmt.Sprintf("存活端口数量: %d", len(alivePorts))) - if Common.IsPortScan() { - return - } + // 处理活跃端口 + alivePorts := getAlivePorts(hosts) + if len(alivePorts) > 0 { + targetInfos = prepareTargetInfos(alivePorts, info) } - - if len(Common.HostPort) > 0 { - alivePorts = append(alivePorts, Common.HostPort...) - alivePorts = Common.RemoveDuplicate(alivePorts) - Common.HostPort = nil - Common.LogInfo(fmt.Sprintf("存活端口数量: %d", len(alivePorts))) - } - - targetInfos = prepareTargetInfos(alivePorts, info) } - for _, url := range Common.URLs { - urlInfo := info - urlInfo.Url = url - targetInfos = append(targetInfos, urlInfo) - } + // 添加 URL 扫描目标 + targetInfos = appendURLTargets(targetInfos, info) + // 如果有扫描目标,执行漏洞扫描 if len(targetInfos) > 0 { Common.LogInfo("开始漏洞扫描") executeScans(targetInfos, ch, wg) } } +// shouldPingScan 判断是否需要进行 ping 扫描 +func shouldPingScan(hosts []string) bool { + return (Common.DisablePing == false && len(hosts) > 1) || Common.IsICMPScan() +} + +// getAlivePorts 获取存活端口 +func getAlivePorts(hosts []string) []string { + var alivePorts []string + if Common.IsWebScan() { + alivePorts = NoPortScan(hosts, Common.Ports) + } else if len(hosts) > 0 { + alivePorts = PortScan(hosts, Common.Ports, Common.Timeout) + Common.LogInfo(fmt.Sprintf("存活端口数量: %d", len(alivePorts))) + if Common.IsPortScan() { + return nil // 结束扫描 + } + } + + // 合并传入的端口信息 + if len(Common.HostPort) > 0 { + alivePorts = append(alivePorts, Common.HostPort...) + alivePorts = Common.RemoveDuplicate(alivePorts) + Common.HostPort = nil + Common.LogInfo(fmt.Sprintf("存活端口数量: %d", len(alivePorts))) + } + + return alivePorts +} + +// appendURLTargets 添加 URL 扫描目标 +func appendURLTargets(targetInfos []Common.HostInfo, baseInfo Common.HostInfo) []Common.HostInfo { + for _, url := range Common.URLs { + urlInfo := baseInfo + urlInfo.Url = url + targetInfos = append(targetInfos, urlInfo) + } + return targetInfos +} + // prepareTargetInfos 准备扫描目标信息 func prepareTargetInfos(alivePorts []string, baseInfo Common.HostInfo) []Common.HostInfo { var infos []Common.HostInfo @@ -183,8 +224,58 @@ func prepareTargetInfos(alivePorts []string, baseInfo Common.HostInfo) []Common. return infos } +// 扫描任务结构体定义 +type ScanTask struct { + pluginName string + target Common.HostInfo +} + +// executeScans 执行扫描任务 func executeScans(targets []Common.HostInfo, ch *chan struct{}, wg *sync.WaitGroup) { mode := Common.GetScanMode() + + // 获取待执行的插件列表 + pluginsToRun, isSinglePlugin := getPluginsToRun(mode) + + var tasks []ScanTask + actualTasks := 0 + loadedPlugins := make([]string, 0) + + // 遍历目标,收集任务 + for _, target := range targets { + targetPort, _ := strconv.Atoi(target.Ports) + for _, pluginName := range pluginsToRun { + plugin, exists := Common.PluginManager[pluginName] + if !exists { + continue + } + + taskAdded, newTasks := collectScanTasks(plugin, target, targetPort, pluginName, isSinglePlugin) + if taskAdded { + actualTasks += len(newTasks) + loadedPlugins = append(loadedPlugins, pluginName) + tasks = append(tasks, newTasks...) + } + } + } + + // 去重并排序插件 + finalPlugins := getUniquePlugins(loadedPlugins) + + // 输出加载的插件信息 + Common.LogInfo(fmt.Sprintf("加载的插件: %s", strings.Join(finalPlugins, ", "))) + + // 初始化进度条 + initializeProgressBar(actualTasks) + + // 执行收集的任务 + for _, task := range tasks { + AddScan(task.pluginName, task.target, ch, wg) + } +} + +// 获取待执行插件列表 +func getPluginsToRun(mode string) ([]string, bool) { var pluginsToRun []string isSinglePlugin := false @@ -195,82 +286,28 @@ func executeScans(targets []Common.HostInfo, ch *chan struct{}, wg *sync.WaitGro isSinglePlugin = true } - loadedPlugins := make([]string, 0) - actualTasks := 0 + return pluginsToRun, isSinglePlugin +} - type ScanTask struct { - pluginName string - target Common.HostInfo - } - tasks := make([]ScanTask, 0) +// 收集扫描任务 +func collectScanTasks(plugin Common.ScanPlugin, target Common.HostInfo, targetPort int, pluginName string, isSinglePlugin bool) (bool, []ScanTask) { + var tasks []ScanTask + taskAdded := false - // 第一次遍历:计算任务数和收集要执行的插件 - for _, target := range targets { - targetPort, _ := strconv.Atoi(target.Ports) - - for _, pluginName := range pluginsToRun { - plugin, exists := Common.PluginManager[pluginName] - if !exists { - continue - } - - // Web模式特殊处理 - if WebScan { - actualTasks++ - loadedPlugins = append(loadedPlugins, pluginName) - tasks = append(tasks, ScanTask{ - pluginName: pluginName, - target: target, - }) - continue - } - - // 本地扫描模式 - if LocalScan { - if len(plugin.Ports) == 0 { - actualTasks++ - loadedPlugins = append(loadedPlugins, pluginName) - tasks = append(tasks, ScanTask{ - pluginName: pluginName, - target: target, - }) - } - continue - } - - // 单插件模式 - if isSinglePlugin { - actualTasks++ - loadedPlugins = append(loadedPlugins, pluginName) - tasks = append(tasks, ScanTask{ - pluginName: pluginName, - target: target, - }) - continue - } - - // 常规模式 - if len(plugin.Ports) > 0 { - if plugin.HasPort(targetPort) { - actualTasks++ - loadedPlugins = append(loadedPlugins, pluginName) - tasks = append(tasks, ScanTask{ - pluginName: pluginName, - target: target, - }) - } - } else { - actualTasks++ - loadedPlugins = append(loadedPlugins, pluginName) - tasks = append(tasks, ScanTask{ - pluginName: pluginName, - target: target, - }) - } - } + // Web模式特殊处理 + if WebScan || LocalScan || isSinglePlugin || len(plugin.Ports) == 0 || plugin.HasPort(targetPort) { + taskAdded = true + tasks = append(tasks, ScanTask{ + pluginName: pluginName, + target: target, + }) } - // 去重并输出实际加载的插件 + return taskAdded, tasks +} + +// 获取去重后的插件列表 +func getUniquePlugins(loadedPlugins []string) []string { uniquePlugins := make(map[string]struct{}) for _, p := range loadedPlugins { uniquePlugins[p] = struct{}{} @@ -280,11 +317,13 @@ func executeScans(targets []Common.HostInfo, ch *chan struct{}, wg *sync.WaitGro for p := range uniquePlugins { finalPlugins = append(finalPlugins, p) } + sort.Strings(finalPlugins) + return finalPlugins +} - Common.LogInfo(fmt.Sprintf("加载的插件: %s", strings.Join(finalPlugins, ", "))) - - // 初始化进度条 +// 初始化进度条 +func initializeProgressBar(actualTasks int) { if Common.ShowProgress { Common.ProgressBar = progressbar.NewOptions(actualTasks, progressbar.OptionEnableColorCodes(true), @@ -303,11 +342,6 @@ func executeScans(targets []Common.HostInfo, ch *chan struct{}, wg *sync.WaitGro progressbar.OptionSetRenderBlankState(true), ) } - - // 执行收集的任务 - for _, task := range tasks { - AddScan(task.pluginName, task.target, ch, wg) - } } // finishScan 完成扫描任务 @@ -324,7 +358,7 @@ func finishScan(wg *sync.WaitGroup) { // Mutex用于保护共享资源的并发访问 var Mutex = &sync.Mutex{} -// AddScan +// AddScan 添加扫描任务并启动扫描 func AddScan(plugin string, info Common.HostInfo, ch *chan struct{}, wg *sync.WaitGroup) { *ch <- struct{}{} wg.Add(1) @@ -335,20 +369,14 @@ func AddScan(plugin string, info Common.HostInfo, ch *chan struct{}, wg *sync.Wa <-*ch }() - Mutex.Lock() + // 使用原子操作更新扫描计数 atomic.AddInt64(&Common.Num, 1) - Mutex.Unlock() + // 执行扫描插件 ScanFunc(&plugin, &info) - Common.OutputMutex.Lock() - atomic.AddInt64(&Common.End, 1) - if Common.ProgressBar != nil { - // 清除当前行 - fmt.Print("\033[2K\r") - Common.ProgressBar.Add(1) - } - Common.OutputMutex.Unlock() + // 更新扫描结束后的状态 + updateScanProgress(&info) }() } @@ -371,12 +399,15 @@ func ScanFunc(name *string, info *Common.HostInfo) { } } -// IsContain 检查切片中是否包含指定元素 -func IsContain(items []string, item string) bool { - for _, eachItem := range items { - if eachItem == item { - return true - } +// updateScanProgress 更新扫描进度 +func updateScanProgress(info *Common.HostInfo) { + // 输出互斥锁更新进度条 + Common.OutputMutex.Lock() + atomic.AddInt64(&Common.End, 1) + if Common.ProgressBar != nil { + // 清除当前行并更新进度条 + fmt.Print("\033[2K\r") + Common.ProgressBar.Add(1) } - return false + Common.OutputMutex.Unlock() }