diff --git a/Plugins/SSH.go b/Plugins/SSH.go index e8417d0..040b601 100644 --- a/Plugins/SSH.go +++ b/Plugins/SSH.go @@ -7,6 +7,7 @@ import ( "io/ioutil" "net" "strings" + "sync" "time" ) @@ -15,31 +16,98 @@ func SshScan(info *Common.HostInfo) (tmperr error) { return } + threads := 10 // 设置线程数 + taskChan := make(chan struct { + user string + pass string + }, len(Common.Userdict["ssh"])*len(Common.Passwords)) + + // 创建结果通道 + resultChan := make(chan error, threads) + + // 生成所有任务 for _, user := range Common.Userdict["ssh"] { for _, pass := range Common.Passwords { pass = strings.Replace(pass, "{user}", user, -1) - success, err := SshConn(info, user, pass) + taskChan <- struct { + user string + pass string + }{user, pass} + } + } + close(taskChan) - if err != nil { - errlog := fmt.Sprintf("[-] SSH认证失败 %v:%v User:%v Pass:%v Err:%v", - info.Host, info.Ports, user, pass, err) - Common.LogError(errlog) - tmperr = err + // 启动工作线程 + var wg sync.WaitGroup + for i := 0; i < threads; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for task := range taskChan { + // 为每个任务创建结果通道 + done := make(chan struct { + success bool + err error + }) - if Common.CheckErrs(err) { - return err + // 执行SSH连接 + go func(user, pass string) { + success, err := SshConn(info, user, pass) + done <- struct { + success bool + err error + }{success, err} + }(task.user, task.pass) + + // 等待结果或超时 + var err error + select { + case result := <-done: + err = result.err + if result.success { + resultChan <- nil + return + } + case <-time.After(time.Duration(Common.Timeout) * time.Second): + err = fmt.Errorf("连接超时") + } + + if err != nil { + errlog := fmt.Sprintf("[-] SSH认证失败 %v:%v User:%v Pass:%v Err:%v", + info.Host, info.Ports, task.user, task.pass, err) + Common.LogError(errlog) + + if Common.CheckErrs(err) { + resultChan <- err + return + } + } + + if Common.SshKeyPath != "" { + resultChan <- err + return } } + resultChan <- nil + }() + } - if success { - return nil - } + // 等待所有线程完成 + go func() { + wg.Wait() + close(resultChan) + }() - if Common.SshKeyPath != "" { + // 检查结果 + for err := range resultChan { + if err != nil { + tmperr = err + if Common.CheckErrs(err) { return err } } } + return tmperr } @@ -66,7 +134,7 @@ func SshConn(info *Common.HostInfo, user string, pass string) (flag bool, err er HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error { return nil }, - Timeout: time.Duration(Common.Timeout), + Timeout: time.Duration(Common.Timeout) * time.Millisecond, } client, err := ssh.Dial("tcp", fmt.Sprintf("%v:%v", info.Host, info.Ports), config)