feat: refactor

master
Dave S. 2 years ago
parent 68fee6bbb7
commit ac9935a4e0
  1. 2
      bigint.go
  2. 383
      config.go
  3. 60
      conn.go
  4. 12
      crypt.go
  5. 348
      endpoint.go
  6. 154
      eparse.go
  7. 31
      legacy-winbox.go
  8. 28
      log.go
  9. 3
      math.go
  10. 8
      mtbf.go
  11. 51
      results.go
  12. 3
      service.go
  13. 328
      source.go
  14. 130
      task.go
  15. 130
      thread.go
  16. 15
      winbox.go

@ -1,5 +1,7 @@
package main package main
// bigint.go: methods for bigint operation chaining
import ( import (
"math/big" "math/big"
) )

@ -8,170 +8,143 @@ import (
"time" "time"
) )
// configParameterOptions represents additional options for a configParameter.
type configParameterOptions struct {
sw, hidden, command bool
callback func()
}
// configParameter represents a single configuration parameter.
type configParameter struct {
name string // duplicated in configMap, but also saved here for convenience
value, def any // value and default value
description string // description for this parameter
parsed bool // true if it was successfully parsed from commandline
configParameterOptions
}
type configParameterTypeUnion = interface {
bool | int | uint | float64 | string | []int | []uint | []float64 | []string | []bool
}
var configMap = map[string]configParameter{} var configMap = map[string]configParameter{}
var configAliasMap = map[string]string{} var configAliasMap = map[string]string{}
var configParsingFinished = false func init() {
registerCommand("help", "show program usage", func() {
// --- log("", 0, "options:")
// registration
func registerConfigParameter[T configParameterTypeUnion](name string, def T, description string, opts configParameterOptions) {
name = strings.ToLower(name)
_, ok := configMap[name]
failIf(ok, "cannot register config parameter (already exists): \"%v\"", name)
_, ok = configAliasMap[name] parms := make([]string, 0, len(configMap))
failIf(ok, "cannot register config parameter (already exists as an alias): \"%v\"", name) for key := range configMap {
parms = append(parms, key)
}
sort.Strings(parms)
failIf(opts.command && opts.callback == nil, "\"%v\" is defined as a command but callback is missing", name) for _, parmName := range parms {
parm := configMap[parmName]
configMap[name] = configParameter{name: name, value: def, def: def, description: description, if parm.hidden {
configParameterOptions: opts} continue
} }
func registerParam[T configParameterTypeUnion](name string, def T, description string) { header := "-" + parm.name
registerConfigParameter(name, def, description, configParameterOptions{}) if len(parm.name) > 1 {
header = "-" + header
} }
func registerParamEx[T configParameterTypeUnion](name string, def T, description string, options configParameterOptions) { aliases := []string{}
registerConfigParameter(name, def, description, options) for alias, target := range configAliasMap {
if target == parm.name {
if len(alias) == 1 {
aliases = append(aliases, "-"+alias)
} else {
aliases = append(aliases, "--"+alias)
} }
func registerParamHidden[T configParameterTypeUnion](name string, def T) {
registerConfigParameter(name, def, "", configParameterOptions{hidden: true})
} }
func registerParamWithCallback[T configParameterTypeUnion](name string, def T, description string, callback func()) {
registerConfigParameter(name, def, description, configParameterOptions{callback: callback})
} }
func registerCommand(name string, description string, callback func()) { if len(aliases) > 0 {
registerConfigParameter(name, false, description, configParameterOptions{command: true, callback: callback}) sort.Strings(aliases)
header = header + " (aliases: " + strings.Join(aliases, ", ") + ")"
} }
func registerSwitch(name string, description string) { header = header + ":"
registerConfigParameter(name, false, description, configParameterOptions{sw: true}) description := " (description missing)"
if parm.description != "" {
description = " " + parm.description
} }
func registerAlias(alias, target string) { if parm.command || parm.sw {
alias, target = strings.ToLower(alias), strings.ToLower(target) log("", 0, "%s\n%s", header, description)
} else {
_, ok := configAliasMap[alias] log("", 0, "%s\n%s\n default: %v", header, description, parm.value)
failIf(ok, "cannot register alias (already exists): \"%v\"", alias) }
_, ok = configMap[alias]
failIf(ok, "cannot register alias (already exists as a config parameter): \"%v\"", alias)
_, ok = configMap[target]
failIf(!ok, "cannot register alias \"%v\": target \"%v\" does not exist", alias, target)
configAliasMap[alias] = target
} }
// --- log("", 0, "")
// acquisition log("", 0, "examples:")
log("", 0, " single target:")
func getParamGeneric(name string) any { log("", 0, " ./mtbf --ip 127.0.0.1 --port 8291 --login admin --password 12345678 --out-file good.txt")
name = strings.ToLower(name) log("", 0, " multiple targets with multiple passwords:")
log("", 0, " ./mtbf --ip-list ips.txt --port 8291 --login admin --password-list passwords.txt --out-file good.txt")
parm, ok := configMap[name] os.Exit(0)
failIf(!ok, "unknown config parameter: \"%v\"", name) })
failIf(parm.command, "config parameter \"%v\" is a command", name)
if parm.sw { registerAlias("?", "help")
return parm.parsed // switches always return true if they were parsed registerAlias("h", "help")
} else if parm.parsed || parm.hidden {
return parm.value // parsed and hidden parms return their current value
} else {
return parm.def // otherwise, use default value
}
} }
func getParam[T configParameterTypeUnion](name string) T { // configParameterOptions represents additional options for a configParameter.
return getParamGeneric(name).(T) type configParameterOptions struct {
sw, hidden, command bool
callback func()
} }
func getParamInt(name string) int { // configParameter represents a single configuration parameter.
return getParam[int](name) type configParameter struct {
name string // duplicated in configMap, but also saved here for convenience
value, def any // value and default value
description string // description for this parameter
parsed bool // true if it was successfully parsed from commandline
configParameterOptions
} }
func getParamFloat(name string) float64 { type configParameterTypeUnion = interface {
return getParam[float64](name) bool | int | uint | float64 | string | []int | []uint | []float64 | []string | []bool | map[string]bool
} }
func getParamIntSlice(name string) []int { // --------------
return getParam[[]int](name) // parsing
} // --------------
func getParamBool(name string) bool { func parseAppConfig() {
return getParam[bool](name) log("cfg", 1, "parsing config")
}
func getParamSwitch(name string) bool { totalFinalized := 0
return getParamBool(name)
}
func getParamString(name string) string { for i := 1; i < len(os.Args); i++ {
return getParam[string](name) arg := getCmdlineParm(i)
if len(arg) == 0 {
continue
} }
func getParamStringSlice(name string) []string { failIf(arg[0] != '-', "\"%v\" is not a commandline parameter", arg)
return getParam[[]string](name) arg = strings.TrimPrefix(arg, "-")
} arg = strings.TrimPrefix(arg, "-")
func getParamDurationMS(name string) time.Duration { failIf(len(arg) == 0, "\"%v\" is not a commandline parameter", getCmdlineParm(i))
tm := getParam[int](name)
failIf(tm < -1, "\"%v\" can only be set to -1 or a positive value", name)
if tm == -1 {
tm = 0
}
return time.Duration(tm) * time.Millisecond parm, ok := configMap[strings.ToLower(arg)]
} if !ok {
alias, ok := configAliasMap[strings.ToLower(arg)]
failIf(!ok, "unknown commandline parameter: \"%v\"", arg)
// --- parm, ok = configMap[alias]
// setting failIf(!ok, "alias \"%v\" references unknown commandline parameter", arg)
func setParam(name string, value any) { log("cfg", 3, "\"%v\" is aliased to \"%v\"", alias, parm.name)
name = strings.ToLower(name) }
parm, ok := configMap[name] failIf(parm.hidden, "\"%v\" is not a commandline parameter", getCmdlineParm(i))
failIf(!ok, "unknown config parameter: \"%v\"", name) failIf(parm.parsed && !parm.isSlice(), "multiple occurrences of commandline parameter \"%v\" are not allowed", parm.name)
failIf(parm.hidden, "config parameter \"%v\" is hidden and cannot be set", name)
failIf(parm.command, "config parameter \"%v\" is a command and cannot be set", name)
failIf(parm.sw && !value.(bool), "config parameter \"%v\" is a switch and only accepts boolean arguments", name)
parm.value = value if !parm.command {
if parm.callback != nil { if parm.sw {
parm.callback() parm.writeParmValue("true")
} else {
i++
parm.writeParmValue(getCmdlineParm(i))
}
} }
configMap[name] = parm parm.finalize()
totalFinalized++
} }
// --- log("cfg", 1, "parsed %v commandline parameters", totalFinalized)
// parsing }
// getCmdlineParm retrieves a commandline parameter with index i. // getCmdlineParm retrieves a commandline parameter with index i.
func getCmdlineParm(i int) string { func getCmdlineParm(i int) string {
@ -238,117 +211,139 @@ func (parm *configParameter) finalize() {
log("cfg", 2, "parse: %T \"%v\" -> def %v, now %v", parm.value, parm.name, parm.def, parm.value) log("cfg", 2, "parse: %T \"%v\" -> def %v, now %v", parm.value, parm.name, parm.def, parm.value)
} }
func parseAppConfig() { // --------------
log("cfg", 1, "parsing config") // registration
// --------------
totalFinalized := 0
for i := 1; i < len(os.Args); i++ { func registerConfigParameter[T configParameterTypeUnion](name string, def T, description string, opts configParameterOptions) {
arg := getCmdlineParm(i) name = strings.ToLower(name)
if len(arg) == 0 {
continue
}
failIf(arg[0] != '-', "\"%v\" is not a commandline parameter", arg) _, ok := configMap[name]
arg = strings.TrimPrefix(arg, "-") failIf(ok, "cannot register config parameter (already exists): \"%v\"", name)
arg = strings.TrimPrefix(arg, "-")
failIf(len(arg) == 0, "\"%v\" is not a commandline parameter", getCmdlineParm(i)) _, ok = configAliasMap[name]
failIf(ok, "cannot register config parameter (already exists as an alias): \"%v\"", name)
parm, ok := configMap[strings.ToLower(arg)] failIf(opts.command && opts.callback == nil, "\"%v\" is defined as a command but callback is missing", name)
if !ok {
alias, ok := configAliasMap[strings.ToLower(arg)]
failIf(!ok, "unknown commandline parameter: \"%v\"", arg)
parm, ok = configMap[alias] configMap[name] = configParameter{name: name, value: def, def: def, description: description,
failIf(!ok, "alias \"%v\" references unknown commandline parameter", arg) configParameterOptions: opts}
}
log("cfg", 3, "\"%v\" is aliased to \"%v\"", alias, parm.name) func registerParam[T configParameterTypeUnion](name string, def T, description string) {
registerConfigParameter(name, def, description, configParameterOptions{})
} }
failIf(parm.hidden, "\"%v\" is not a commandline parameter", getCmdlineParm(i)) func registerParamHidden[T configParameterTypeUnion](name string, def T) {
failIf(parm.parsed && !parm.isSlice(), "multiple occurrences of commandline parameter \"%v\" are not allowed", parm.name) registerConfigParameter(name, def, "", configParameterOptions{hidden: true})
}
if !parm.command { func registerParamWithCallback[T configParameterTypeUnion](name string, def T, description string, callback func()) {
if parm.sw { registerConfigParameter(name, def, description, configParameterOptions{callback: callback})
parm.writeParmValue("true")
} else {
i++
parm.writeParmValue(getCmdlineParm(i))
} }
func registerCommand(name string, description string, callback func()) {
registerConfigParameter(name, false, description, configParameterOptions{command: true, callback: callback})
} }
parm.finalize() func registerSwitch(name string, description string) {
totalFinalized++ registerConfigParameter(name, false, description, configParameterOptions{sw: true})
} }
log("cfg", 1, "parsed %v commandline parameters", totalFinalized) func registerAlias(alias, target string) {
configParsingFinished = true alias, target = strings.ToLower(alias), strings.ToLower(target)
_, ok := configAliasMap[alias]
failIf(ok, "cannot register alias (already exists): \"%v\"", alias)
_, ok = configMap[alias]
failIf(ok, "cannot register alias (already exists as a config parameter): \"%v\"", alias)
_, ok = configMap[target]
failIf(!ok, "cannot register alias \"%v\": target \"%v\" does not exist", alias, target)
configAliasMap[alias] = target
} }
func showHelp() { // --------------
log("", 0, "options:") // acquisition
// --------------
parms := make([]string, 0, len(configMap)) func getParamGeneric(name string) any {
for key := range configMap { name = strings.ToLower(name)
parms = append(parms, key)
parm, ok := configMap[name]
failIf(!ok, "unknown config parameter: \"%v\"", name)
failIf(parm.command, "config parameter \"%v\" is a command", name)
if parm.sw {
return parm.parsed // switches always return true if they were parsed
} else if parm.parsed || parm.hidden {
return parm.value // parsed and hidden parms return their current value
} else {
return parm.def // otherwise, use default value
}
} }
sort.Strings(parms)
for _, parmName := range parms { func getParam[T configParameterTypeUnion](name string) T {
parm := configMap[parmName] return getParamGeneric(name).(T)
}
if parm.hidden { func getParamInt(name string) int {
continue return getParam[int](name)
} }
header := "-" + parm.name func getParamFloat(name string) float64 {
if len(parm.name) > 1 { return getParam[float64](name)
header = "-" + header
} }
aliases := []string{} func getParamIntSlice(name string) []int {
for alias, target := range configAliasMap { return getParam[[]int](name)
if target == parm.name {
if len(alias) == 1 {
aliases = append(aliases, "-"+alias)
} else {
aliases = append(aliases, "--"+alias)
} }
break
func getParamBool(name string) bool {
return getParam[bool](name)
} }
func getParamSwitch(name string) bool {
return getParamBool(name)
} }
if len(aliases) > 0 { func getParamString(name string) string {
sort.Strings(aliases) return getParam[string](name)
header = header + " (aliases: " + strings.Join(aliases, ", ") + ")"
} }
header = header + ":" func getParamStringSlice(name string) []string {
description := " (description missing)" return getParam[[]string](name)
if parm.description != "" {
description = " " + parm.description
} }
if parm.command || parm.sw { func getParamDurationMS(name string) time.Duration {
log("", 0, "%s\n%s", header, description) tm := getParam[int](name)
} else { failIf(tm < -1, "\"%v\" can only be set to -1 or a positive value", name)
log("", 0, "%s\n%s\n default: %v", header, description, parm.value) if tm == -1 {
tm = 0
} }
return time.Duration(tm) * time.Millisecond
} }
log("", 0, "") // --------------
log("", 0, "examples:") // setting
log("", 0, " single target:") // --------------
log("", 0, " ./mtbf --ip 127.0.0.1 --port 8291 --login admin --password 12345678 --out-file good.txt")
log("", 0, " multiple targets with multiple passwords:")
log("", 0, " ./mtbf --ip-list ips.txt --port 8291 --login admin --password-list passwords.txt --out-file good.txt")
os.Exit(0) func setParam(name string, value any) {
name = strings.ToLower(name)
parm, ok := configMap[name]
failIf(!ok, "unknown config parameter: \"%v\"", name)
failIf(parm.hidden, "config parameter \"%v\" is hidden and cannot be set", name)
failIf(parm.command, "config parameter \"%v\" is a command and cannot be set", name)
failIf(parm.sw && !value.(bool), "config parameter \"%v\" is a switch and only accepts boolean arguments", name)
parm.value = value
if parm.callback != nil {
parm.callback()
} }
func init() { configMap[name] = parm
registerCommand("help", "show program usage", showHelp)
registerAlias("?", "help")
registerAlias("h", "help")
} }

@ -10,17 +10,37 @@ type Connection struct {
dialer net.Dialer dialer net.Dialer
socket net.Conn socket net.Conn
connectTimeout time.Duration connectTimeout time.Duration
readTimeout time.Duration sendTimeout time.Duration
recvTimeout time.Duration
protocol string protocol string
} }
// NewConnection creates a Connection object. func init() {
func NewConnection() *Connection { registerParam("connect-timeout-ms", 3000, "")
registerParam("send-timeout-ms", 2000, "")
registerParam("recv-timeout-ms", 1000, "")
}
// NewConnection creates a Connection object and optionally connects to an Endpoint.
func NewConnection(endpoint *Endpoint) (*Connection, error) {
conn := Connection{} conn := Connection{}
conn.connectTimeout = getParamDurationMS("connect-timeout-ms") conn.connectTimeout = getParamDurationMS("connect-timeout-ms")
conn.readTimeout = getParamDurationMS("read-timeout-ms") conn.sendTimeout = getParamDurationMS("send-timeout-ms")
conn.recvTimeout = getParamDurationMS("recv-timeout-ms")
conn.protocol = "tcp" conn.protocol = "tcp"
return &conn
if endpoint != nil {
return &conn, conn.Connect(endpoint)
} else {
return &conn, nil
}
}
func (conn *Connection) Close() {
if conn.socket != nil {
conn.socket.Close()
conn.socket = nil
}
} }
// Connect initiates a connection to an Endpoint. // Connect initiates a connection to an Endpoint.
@ -31,7 +51,7 @@ func (conn *Connection) Connect(endpoint *Endpoint) (err error) {
log("conn", 2, "cannot connect to \"%v\": %v", endpoint, err.Error()) log("conn", 2, "cannot connect to \"%v\": %v", endpoint, err.Error())
} }
return err return
} }
// SetConnectTimeout sets a custom connect timeout on a Connection. // SetConnectTimeout sets a custom connect timeout on a Connection.
@ -39,12 +59,28 @@ func (conn *Connection) SetConnectTimeout(timeout time.Duration) {
conn.connectTimeout = timeout conn.connectTimeout = timeout
} }
// SetReadTimeout sets a custom read timeout on a Connection. // Send writes data to a Connection.
func (conn *Connection) SetReadTimeout(timeout time.Duration) { func (conn *Connection) Send(data []byte) (err error) {
conn.readTimeout = timeout if len(data) == 0 {
log("conn", 1, "tried to send empty buffer to a socket, ignoring")
return nil
} }
// Send writes data to a Connection. conn.socket.SetWriteDeadline(time.Now().Add(conn.sendTimeout))
func (conn *Connection) Send(data []byte) { _, err = conn.socket.Write(data)
conn.socket.SetReadDeadline(time.Now().Add(conn.readTimeout)) return
}
// Recv receives data from a Connection.
func (conn *Connection) Recv() (data []byte, err error) {
conn.socket.SetReadDeadline(time.Now().Add(conn.recvTimeout))
data = make([]byte, 1024)
n, err := conn.socket.Read(data)
if err != nil {
return nil, err
}
data = data[:n]
return
} }

@ -1,5 +1,7 @@
package main package main
// crypt.go: various cryptographical operations
import ( import (
"crypto/hmac" "crypto/hmac"
cryptoRand "crypto/rand" cryptoRand "crypto/rand"
@ -9,6 +11,11 @@ import (
"strings" "strings"
) )
func init() {
registerSwitch("crypt-predictable-rng", "disable secure rng and use a pseudorandom preseeded rng")
mathRand.Seed(300)
}
func getSHA1Digest(data []byte) []byte { func getSHA1Digest(data []byte) []byte {
array := sha1.Sum(data) array := sha1.Sum(data)
return array[:] return array[:]
@ -96,8 +103,3 @@ func genRandomBytes(n int) ([]byte, error) {
return b, nil return b, nil
} }
func init() {
registerSwitch("crypt-predictable-rng", "disable secure rng and use a pseudorandom preseeded rng")
mathRand.Seed(300)
}

@ -2,14 +2,74 @@ package main
import ( import (
"container/list" "container/list"
"errors"
"net"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
) )
func init() {
endpoints = list.New()
delayedEndpoints = list.New()
registerParam("port", []int{8291}, "one or more default ports")
registerParam("max-aps", 5, "maximum number of attempts per second for an endpoint")
registerSwitch("no-ipv6", "skip IPv6 entries")
registerSwitch("append-default-ports", "always append default ports even for targets in host:port format")
registerSwitch("strict-subnets", "strict subnet behaviour: ignore network and broadcast addresses in /30 and bigger subnets")
registerSwitch("keep-endpoint-on-good", "keep processing endpoint if a login/password was found")
registerParam("conn-ratio", 0.15, "keep a failed endpoint if its bad/good connection ratio is lower than this value")
registerParam("max-bad-after-good-conn", 5, "how many consecutive bad connections to allow after a good connection")
registerParam("max-bad-conn", 20, "always remove endpoint after this many consecutive bad connections")
registerParam("min-bad-conn", 2, "do not consider removing an endpoint if it does not have this many consecutive bad connections")
registerParam("proto-error-ratio", 0.25, "keep endpoints with a protocol error if their protocol error ratio is lower than this value")
registerParam("max-proto-errors", 20, "always remove endpoint after this many consecutive protocol errors")
registerParam("min-proto-errors", 4, "do not consider removing an endpoint if it does not have this many consecutive protocol errors")
registerParam("read-error-ratio", 0.25, "keep endpoints with a read error if their read error ratio is lower than this value")
registerParam("max-read-errors", 20, "always remove endpoint after this many consecutive read errors")
registerParam("min-read-errors", 3, "do not consider removing an endpoint if it does not have this many consecutive read errors")
registerParam("no-response-delay-ms", 2000, "wait for this number of ms if an endpoint does not respond")
registerParam("read-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a read error")
registerParam("protocol-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a protocol error")
}
// FetchEndpoint retrieves an endpoint: first, a delayed list is queried,
// then, if nothing is found, a normal list is searched.
// If all endpoints are delayed, a wait time is returned.
func FetchEndpoint() (e *Endpoint, waitTime time.Duration) {
globalEndpointMutex.Lock()
defer globalEndpointMutex.Unlock()
log("ep", 4, "fetching an endpoint")
e, waitTime = GetDelayedEndpoint()
if e != nil {
log("ep", 4, "fetched a delayed endpoint: \"%v\"", e)
return e, 0
}
el := endpoints.Front()
if el == nil {
if waitTime == 0 {
log("ep", 1, "out of endpoints")
return nil, 0
}
log("ep", 4, "all endpoints are delayed, waiting for %v", waitTime)
return nil, waitTime
}
endpoints.MoveToBack(el)
e = el.Value.(*Endpoint)
log("ep", 4, "fetched a normal endpoint: \"%v\"", e)
return e, 0
}
type Address struct { type Address struct {
ip string // TODO: switch to a static 16-byte array ip string // TODO: switch to a static 16-byte array
port int port int
@ -28,7 +88,8 @@ const (
type Endpoint struct { type Endpoint struct {
addr Address // IP address of an endpoint addr Address // IP address of an endpoint
loginPos, passwordPos SourcePos // login/password cursors loginPos SourcePos
passwordPos SourcePos // login/password cursors
listElement *list.Element // position in list listElement *list.Element // position in list
state EndpointState // which state an endpoint is in state EndpointState // which state an endpoint is in
@ -151,28 +212,10 @@ func (e *Endpoint) Delay(addTime time.Duration) {
} }
} }
// MigrateToNormal moves an Endpoint to a normal queue.
// Endpoint mutex is assumed to be taken.
func (e *Endpoint) MigrateToNormal() {
endpointMutex.Lock()
defer endpointMutex.Unlock()
if e.normalList != nil {
log("ep", 5, "cannot migrate endpoint \"%v\" to normal list: already in the list", e)
} else {
log("ep", 5, "migrating endpoint \"%v\" to normal list", e)
e.normalList = endpoints.PushBack(e)
if e.delayedList != nil {
delayedEndpoints.Remove(e.delayedList)
e.delayedList = nil
}
}
}
// SkipLogin gets the endpoint's current login, // SkipLogin gets the endpoint's current login,
// compares it with user-defined login and skips (advances) it if // compares it with user-defined login and skips (advances) it if
// both logins are equal. // both logins are equal.
func (e *Endpoint) SkipLogin(login) { func (e *Endpoint) SkipLogin(login string) {
// attempt to fetch next login // attempt to fetch next login
curLogin, empty := SrcLogin.FetchOne(&e.loginPos, false) curLogin, empty := SrcLogin.FetchOne(&e.loginPos, false)
if curLogin == login && !empty { // this login has not yet been exhausted? if curLogin == login && !empty { // this login has not yet been exhausted?
@ -278,12 +321,12 @@ func (e *Endpoint) Bad() {
e.consecutiveProtoErrors = 0 e.consecutiveProtoErrors = 0
// The endpoint may be in delayed queue, so push it back to the normal queue. // The endpoint may be in delayed queue, so push it back to the normal queue.
e.MigrateToNormal() e.SetState(ES_Normal)
} }
// Good is an event handler that gets called when // Good is an event handler that gets called when
// an authentication attempt to an Endpoint succeeds. // an authentication attempt to an Endpoint succeeds.
func (e *Endpoint) Good(login) { func (e *Endpoint) Good(login string) {
e.mutex.Lock() e.mutex.Lock()
defer e.mutex.Unlock() defer e.mutex.Unlock()
e.consecutiveProtoErrors = 0 e.consecutiveProtoErrors = 0
@ -291,7 +334,7 @@ func (e *Endpoint) Good(login) {
if !getParamSwitch("keep-endpoint-on-good") { if !getParamSwitch("keep-endpoint-on-good") {
e.Delete() e.Delete()
} else { } else {
e.MigrateToNormal() e.SetState(ES_Normal)
e.SkipLogin(login) e.SkipLogin(login)
} }
} }
@ -372,257 +415,24 @@ func (e *Endpoint) Exhausted() {
func GetDelayedEndpoint() (e *Endpoint, waitTime time.Duration) { func GetDelayedEndpoint() (e *Endpoint, waitTime time.Duration) {
currentTime := time.Now() currentTime := time.Now()
if delayedEndpoints.Empty() { if delayedEndpoints.Len() == 0 {
log("ep", 5, "delayed endpoint list is empty") log("ep", 5, "delayed endpoint list is empty")
return nil, 0 return nil, 0
} }
it := delayedEndpoints.IteratorAt(delayedEndpoints.Left()) minWaitTime := time.Time{}
for {
k, v := it.Key().(time.Time), it.Value().(*Endpoint)
if v == nil {
panic("delayed endpoint list contains an empty endpoint")
return nil, 0
}
if k.After(currentTime) {
log("ep", 5, "no delayed endpoints can be processed at this time")
return nil, k.Sub(currentTime)
}
if k.Before(v.delayUntil) {
log("ep", 5, "delayed endpoint was re-delayed: removing lingering definition")
defer delayedEndpoints.Remove(k)
it.Next()
continue
}
if v.delayUntil.IsZero() {
log("ep", 5, "delayed endpoint is already in normal queue: removing lingering definition")
defer delayedEndpoints.Remove(k)
it.Next()
continue
}
defer delayedEndpoints.Remove(k)
return v, 0
}
log("ep", 5, "delayed endpoint list was holding only lingering definitions and is now empty")
return nil, 0
}
// FetchEndpoint retrieves an endpoint: first, a delayed list is queried,
// then, if nothing is found, a normal list is searched,
// and (TODO) if this list is empty or will soon be emptied,
// a new batch of endpoints gets created.
func FetchEndpoint() (e *Endpoint, waitTime time.Duration) {
globalEndpointMutex.Lock()
defer globalEndpointMutex.Unlock()
log("ep", 4, "fetching an endpoint")
e, waitTime = GetDelayedEndpoint()
if e != nil {
log("ep", 4, "fetched a delayed endpoint: \"%v\"", e)
return e, 0
}
el := endpoints.Front()
if el == nil {
if waitTime == 0 {
log("ep", 1, "out of endpoints")
return nil, 0
}
log("ep", 4, "all endpoints are delayed, waiting for %v", waitTime)
return nil, waitTime
}
endpoints.MoveToBack(el)
e = el.Value.(*Endpoint)
log("ep", 4, "fetched an endpoint: \"%v\"", e)
return e, 0
}
// ---
// ---
// ---
// Safety feature, to avoid expanding subnets into a huge amount of IPs.
const maxNetmaskSize = 22 // expands into /10 for IPv4
// RegisterEndpoint builds an Endpoint and puts it to a global list of endpoints.
func RegisterEndpoint(ip string, ports []int, isIPv6 bool) int {
for _, port := range ports {
ep := Endpoint{addr: Address{ip: ip, port: port, v6: isIPv6}}
ep.loginPos.Reset()
ep.passwordPos.Reset()
ep.listElement = endpoints.PushBack(&ep)
log("ep", 3, "registered endpoint: %v", &ep)
}
return len(ports)
}
func incIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}
// parseCIDR registers multiple endpoints from a CIDR netmask.
func parseCIDR(ip string, ports []int, isIPv6 bool) int {
na, nm, err := net.ParseCIDR(ip)
if err != nil {
log("ep", 0, "failed to parse CIDR notation for \"%v\": %v", ip, err.Error())
return 0
}
mask, maskBits := nm.Mask.Size()
if mask < maskBits-maxNetmaskSize {
log("ep", 0, "ignoring out of safe bounds CIDR netmask for \"%v\": %v (max: %v, allowed: %v)", ip, mask, maskBits, maxNetmaskSize)
return 0
}
curHost := 0
maxHost := 1<<(maskBits-mask) - 1
numParsed := 0
strict := getParamSwitch("strict-subnets")
log("ep", 2, "expanding CIDR: \"%v\" to %v hosts", ip, maxHost+1)
for expIP := na.Mask(nm.Mask); nm.Contains(expIP); incIP(expIP) {
if strict && (curHost == 0 || curHost == maxHost) && maskBits-mask >= 2 {
log("ep", 1, "ignoring network/broadcast address due to strict-subnets: \"%v\"", expIP.String())
} else {
numParsed += RegisterEndpoint(expIP.String(), ports, isIPv6)
}
curHost++
}
return numParsed
}
// parseIPOrCIDR expands plain IP or CIDR to multiple endpoints.
func parseIPOrCIDR(ip string, ports []int, isIPv6 bool) int {
// ip may be a domain name, a CIDR subnet or an IP address
// CIDR subnets must be expanded to plain IPs
if strings.LastIndex(ip, "/") >= 0 { // this is a CIDR subnet
return parseCIDR(ip, ports, isIPv6)
} else if strings.Count(ip, "/") > 1 { // invalid CIDR notation
log("ep", 0, "invalid CIDR subnet format: \"%v\", ignoring", ip)
return 0
} else { // otherwise, just register
return RegisterEndpoint(ip, ports, isIPv6)
}
}
// extractIPAndPort extracts all endpoint components.
func extractIPAndPort(str string) (ip string, port int, err error) {
var portString string
ip, portString, err = net.SplitHostPort(str)
if err != nil {
return "", 0, err
}
port, err = strconv.Atoi(portString)
if err != nil {
return "", 0, err
}
if port <= 0 || port > 65535 {
return "", 0, errors.New("invalid port: " + strconv.Itoa(port))
}
return ip, port, nil
}
// ParseEndpoints takes a string slice of IPs/CIDR subnets and converts it to a list of endpoints.
func ParseEndpoints(source []string) {
log("ep", 1, "parsing endpoints")
totalIPv6Skipped := 0
numParsed := 0
for _, str := range source {
if !strings.Contains(str, ":") {
// no ":": this is an ipv4/dn without port,
// parse it with all known ports
numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), false)
} else {
// either ipv4/dn with port, or ipv6 with/without port
isIPv6 := strings.Count(str, ":") > 1
if isIPv6 && getParamSwitch("no-ipv6") {
totalIPv6Skipped++
continue
}
if !strings.Contains(str, "]:") && strings.Contains(str, "::") {
// ipv6 without port
numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), true)
continue
}
ip, port, err := extractIPAndPort(str)
if err != nil {
log("ep", 0, "failed to extract ip/port for \"%v\": %v, ignoring endpoint", str, err.Error())
continue
}
ports := []int{port} for e := delayedEndpoints.Front(); e != nil; e = e.Next() {
// append all default ports dt := e.Value.(*Endpoint)
if getParamSwitch("append-default-ports") { if minWaitTime.IsZero() || (dt.delayUntil.Before(minWaitTime) && dt.delayUntil.After(currentTime)) {
for _, port2 := range getParamIntSlice("port") { minWaitTime = dt.delayUntil
if port != port2 {
ports = append(ports, port2)
}
}
} }
numParsed += parseIPOrCIDR(ip, ports, isIPv6) if dt.delayUntil.Before(currentTime) {
} delayedEndpoints.Remove(e)
return dt, 0
} }
logIf(totalIPv6Skipped > 0, "ep", 0, "skipped %v IPv6 targets due to no-ipv6 flag", totalIPv6Skipped)
log("ep", 1, "finished parsing endpoints: parsed %v out of total %v", numParsed, endpoints.Len())
} }
func init() { return nil, minWaitTime.Sub(currentTime)
endpoints = list.New()
delayedEndpoints = list.New()
registerParam("port", []int{8291}, "one or more default ports")
registerParam("max-aps", 5, "maximum number of attempts per second for an endpoint")
registerSwitch("no-ipv6", "skip IPv6 entries")
registerSwitch("append-default-ports", "always append default ports even for targets in host:port format")
registerSwitch("strict-subnets", "strict subnet behaviour: ignore network and broadcast addresses in /30 and bigger subnets")
registerSwitch("keep-endpoint-on-good", "keep processing endpoint if a login/password was found")
registerParam("conn-ratio", 0.15, "keep a failed endpoint if its bad/good connection ratio is lower than this value")
registerParam("max-bad-after-good-conn", 5, "how many consecutive bad connections to allow after a good connection")
registerParam("max-bad-conn", 20, "always remove endpoint after this many consecutive bad connections")
registerParam("min-bad-conn", 2, "do not consider removing an endpoint if it does not have this many consecutive bad connections")
registerParam("proto-error-ratio", 0.25, "keep endpoints with a protocol error if their protocol error ratio is lower than this value")
registerParam("max-proto-errors", 20, "always remove endpoint after this many consecutive protocol errors")
registerParam("min-proto-errors", 4, "do not consider removing an endpoint if it does not have this many consecutive protocol errors")
registerParam("read-error-ratio", 0.25, "keep endpoints with a read error if their read error ratio is lower than this value")
registerParam("max-read-errors", 20, "always remove endpoint after this many consecutive read errors")
registerParam("min-read-errors", 3, "do not consider removing an endpoint if it does not have this many consecutive read errors")
registerParam("no-response-delay-ms", 2000, "wait for this number of ms if an endpoint does not respond")
registerParam("read-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a read error")
registerParam("protocol-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a protocol error")
} }

@ -0,0 +1,154 @@
package main
import (
"errors"
"net"
"strconv"
"strings"
)
// Safety feature, to avoid expanding subnets into a huge amount of IPs.
const maxNetmaskSize = 22 // expands into /10 for IPv4
// RegisterEndpoint builds an Endpoint and puts it to a global list of endpoints.
func RegisterEndpoint(ip string, ports []int, isIPv6 bool) int {
for _, port := range ports {
ep := Endpoint{addr: Address{ip: ip, port: port, v6: isIPv6}}
ep.loginPos.Reset()
ep.passwordPos.Reset()
ep.listElement = endpoints.PushBack(&ep)
log("ep", 3, "registered endpoint: %v", &ep)
}
return len(ports)
}
func incIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}
// parseCIDR registers multiple endpoints from a CIDR netmask.
func parseCIDR(ip string, ports []int, isIPv6 bool) int {
na, nm, err := net.ParseCIDR(ip)
if err != nil {
log("ep", 0, "failed to parse CIDR notation for \"%v\": %v", ip, err.Error())
return 0
}
mask, maskBits := nm.Mask.Size()
if mask < maskBits-maxNetmaskSize {
log("ep", 0, "ignoring out of safe bounds CIDR netmask for \"%v\": %v (max: %v, allowed: %v)", ip, mask, maskBits, maxNetmaskSize)
return 0
}
curHost := 0
maxHost := 1<<(maskBits-mask) - 1
numParsed := 0
strict := getParamSwitch("strict-subnets")
log("ep", 2, "expanding CIDR: \"%v\" to %v hosts", ip, maxHost+1)
for expIP := na.Mask(nm.Mask); nm.Contains(expIP); incIP(expIP) {
if strict && (curHost == 0 || curHost == maxHost) && maskBits-mask >= 2 {
log("ep", 1, "ignoring network/broadcast address due to strict-subnets: \"%v\"", expIP.String())
} else {
numParsed += RegisterEndpoint(expIP.String(), ports, isIPv6)
}
curHost++
}
return numParsed
}
// parseIPOrCIDR expands plain IP or CIDR to multiple endpoints.
func parseIPOrCIDR(ip string, ports []int, isIPv6 bool) int {
// ip may be a domain name, a CIDR subnet or an IP address
// CIDR subnets must be expanded to plain IPs
if strings.LastIndex(ip, "/") >= 0 { // this is a CIDR subnet
return parseCIDR(ip, ports, isIPv6)
} else if strings.Count(ip, "/") > 1 { // invalid CIDR notation
log("ep", 0, "invalid CIDR subnet format: \"%v\", ignoring", ip)
return 0
} else { // otherwise, just register
return RegisterEndpoint(ip, ports, isIPv6)
}
}
// extractIPAndPort extracts all endpoint components.
func extractIPAndPort(str string) (ip string, port int, err error) {
var portString string
ip, portString, err = net.SplitHostPort(str)
if err != nil {
return "", 0, err
}
port, err = strconv.Atoi(portString)
if err != nil {
return "", 0, err
}
if port <= 0 || port > 65535 {
return "", 0, errors.New("invalid port: " + strconv.Itoa(port))
}
return ip, port, nil
}
// ParseEndpoints takes a string slice of IPs/CIDR subnets and converts it to a list of endpoints.
func ParseEndpoints(source []string) {
log("ep", 1, "parsing endpoints")
totalIPv6Skipped := 0
numParsed := 0
for _, str := range source {
if !strings.Contains(str, ":") {
// no ":": this is an ipv4/dn without port,
// parse it with all known ports
numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), false)
} else {
// either ipv4/dn with port, or ipv6 with/without port
isIPv6 := strings.Count(str, ":") > 1
if isIPv6 && getParamSwitch("no-ipv6") {
totalIPv6Skipped++
continue
}
if !strings.Contains(str, "]:") && strings.Contains(str, "::") {
// ipv6 without port
numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), true)
continue
}
ip, port, err := extractIPAndPort(str)
if err != nil {
log("ep", 0, "failed to extract ip/port for \"%v\": %v, ignoring endpoint", str, err.Error())
continue
}
ports := []int{port}
// append all default ports
if getParamSwitch("append-default-ports") {
for _, port2 := range getParamIntSlice("port") {
if port != port2 {
ports = append(ports, port2)
}
}
}
numParsed += parseIPOrCIDR(ip, ports, isIPv6)
}
}
logIf(totalIPv6Skipped > 0, "ep", 0, "skipped %v IPv6 targets due to no-ipv6 flag", totalIPv6Skipped)
log("ep", 1, "finished parsing endpoints: parsed %v out of total %v", numParsed, endpoints.Len())
}

@ -33,7 +33,6 @@ func (el *M2Element) String() string {
return fmt.Sprintf("code=%v,type=%T,value=%v", el.code, el.value, el.value) return fmt.Sprintf("code=%v,type=%T,value=%v", el.code, el.value, el.value)
} }
type M2Message struct { type M2Message struct {
el []M2Element el []M2Element
} }
@ -120,7 +119,6 @@ func (m2 *M2Message) Bytes() []byte {
return append(header, res...) return append(header, res...)
} }
func (m2 *M2Message) ParseM2Element(buf io.Reader) error { func (m2 *M2Message) ParseM2Element(buf io.Reader) error {
var codeAndType uint32 var codeAndType uint32
err := binary.Read(buf, binary.LittleEndian, &codeAndType) err := binary.Read(buf, binary.LittleEndian, &codeAndType)
@ -246,32 +244,22 @@ func ParseM2Messages(src []byte) (messages []M2Message, err error) {
} }
} }
log("lw", 3, "m2 eof after %v messages", len(messages)) log("lw", 3, "m2 eof after %v messages", len(messages))
return messages, nil return messages, nil
} }
type LegacyWinbox struct { type LegacyWinbox struct {
task *Task task *Task
conn *Connection
stage int stage int
m2 []M2Message m2 []M2Message
} }
func NewLegacyWinbox(task *Task) *LegacyWinbox { func NewLegacyWinbox(task *Task, conn *Connection) *LegacyWinbox {
lw := LegacyWinbox{task: task, stage: -1, m2: []M2Message{}} lw := LegacyWinbox{task: task, conn: conn, stage: -1, m2: []M2Message{}}
return &lw return &lw
} }
// req1 // req1
func (lw *LegacyWinbox) MTReqList() []byte { func (lw *LegacyWinbox) MTReqList() []byte {
m2 := NewM2Message() m2 := NewM2Message()
@ -296,7 +284,6 @@ func (lw *LegacyWinbox) MTGetSid(m2 []M2Message) *M2Element {
return nil return nil
} }
// req2 // req2
func (lw *LegacyWinbox) MTReqChallenge(sid *M2Element) []byte { func (lw *LegacyWinbox) MTReqChallenge(sid *M2Element) []byte {
m2 := NewM2Message() m2 := NewM2Message()
@ -321,7 +308,6 @@ func (lw *LegacyWinbox) MTGetSalt(m2 []M2Message) M2Hash {
return "" return ""
} }
// req3 // req3
func (lw *LegacyWinbox) MTReqAuth(sid *M2Element, login, digest, salt string) []byte { func (lw *LegacyWinbox) MTReqAuth(sid *M2Element, login, digest, salt string) []byte {
m2 := NewM2Message() m2 := NewM2Message()
@ -357,27 +343,22 @@ func (lw *LegacyWinbox) MTGetResult(m2 []M2Message) (res bool, err error) {
return false, errors.New("no auth marker found") return false, errors.New("no auth marker found")
} }
func (lw *LegacyWinbox) SendRecv(buf []byte) (res []byte, err error) { func (lw *LegacyWinbox) SendRecv(buf []byte) (res []byte, err error) {
_, err = lw.task.conn.Write(buf) err = lw.conn.Send(buf)
if err != nil { if err != nil {
log("lw", 1, "failed to send: %v", err.Error()) log("lw", 1, "failed to send: %v", err.Error())
return nil, err return nil, err
} }
resp := make([]byte, 1024) resp, err := lw.conn.Recv()
n, err := lw.task.conn.Read(resp)
if err != nil { if err != nil {
log("lw", 1, "failed to recv: %v", err.Error()) log("lw", 1, "failed to recv: %v", err.Error())
return nil, err return nil, err
} }
return resp[:n], nil return resp, nil
} }
func (lw *LegacyWinbox) TryLogin() (res bool, err error) { func (lw *LegacyWinbox) TryLogin() (res bool, err error) {
log("lw", 2, "login: stage 1, req_list") log("lw", 2, "login: stage 1, req_list")
r1, err := lw.SendRecv(lw.MTReqList()) r1, err := lw.SendRecv(lw.MTReqList())

@ -1,13 +1,22 @@
package main package main
// log.go: logging
import ( import (
"fmt" "fmt"
"os" "os"
"strings" "strings"
) )
func init() {
registerParam("log-level", 0, "max log level, useful for debugging. -1 logs everything")
registerParamWithCallback("log-modules", []string{}, "always log output from these modules", updateModuleMap)
registerParamWithCallback("no-log-modules", []string{}, "never log output from these modules", updateModuleMap)
registerParamHidden("log-module-map", map[string]bool{})
}
func shouldLog(facility string, level, maxLevel int) bool { func shouldLog(facility string, level, maxLevel int) bool {
moduleMap := CfgGet("log-module-map").(map[string]bool) moduleMap := getParam[map[string]bool]("log-module-map")
logModule, ok := moduleMap[strings.ToLower(facility)] logModule, ok := moduleMap[strings.ToLower(facility)]
if ok { if ok {
@ -22,7 +31,7 @@ func shouldLog(facility string, level, maxLevel int) bool {
} }
func log(facility string, level int, s string, params ...interface{}) { func log(facility string, level int, s string, params ...interface{}) {
maxLevel := CfgGetInt("log-level") maxLevel := getParamInt("log-level")
if !shouldLog(facility, level, maxLevel) { if !shouldLog(facility, level, maxLevel) {
return return
} }
@ -64,8 +73,8 @@ func failIf(condition bool, s string, params ...interface{}) {
} }
func updateModuleMap() { func updateModuleMap() {
logModules := CfgGet("log-modules").([]string) logModules := getParamStringSlice("log-modules")
noLogModules := CfgGet("no-log-modules").([]string) noLogModules := getParamStringSlice("no-log-modules")
newMap := map[string]bool{} newMap := map[string]bool{}
@ -76,16 +85,9 @@ func updateModuleMap() {
for _, module := range noLogModules { for _, module := range noLogModules {
module = strings.ToLower(module) module = strings.ToLower(module)
failIf(newMap[module] == true, "log module \"%v\" is defined both in log-modules and no-log-modules", module) failIf(newMap[module], "log module \"%v\" is defined both in log-modules and no-log-modules", module)
newMap[module] = false newMap[module] = false
} }
CfgSet("log-module-map", newMap) setParam("log-module-map", newMap)
}
func init() {
CfgRegister("log-level", 0, "max log level, useful for debugging. -1 logs everything")
CfgRegisterCallback("log-modules", []string{}, "always log output from these modules", updateModuleMap)
CfgRegisterCallback("no-log-modules", []string{}, "never log output from these modules", updateModuleMap)
CfgRegisterHidden("log-module-map", map[string]bool{})
} }

@ -1,7 +1,8 @@
package main package main
// math.go: mathematical routines
import ( import (
_ "fmt"
"math" "math"
) )

@ -4,13 +4,13 @@ func main() {
log("main", 0, "mtbf: Mikrotik RouterOS bruteforce | v1.0.1") log("main", 0, "mtbf: Mikrotik RouterOS bruteforce | v1.0.1")
parseAppConfig() parseAppConfig()
OpenOutFile() go ResultService()
defer CloseOutFile() defer EndResults()
LoadSources() LoadSources()
defer CloseSources() defer CloseSources()
wg := InitializeThreads() ThreadService()
WaitForThreads(wg)
log("main", 0, "finished") log("main", 0, "finished")
} }

@ -1,25 +1,50 @@
package main package main
// results.go: saves good results to a file or console
import ( import (
"fmt" "fmt"
"os" "os"
) )
func init() {
registerParam("out-file", "good.txt", "results will be saved in this file")
registerAlias("o", "out-file")
}
var outFile *os.File var outFile *os.File
var resultChannel chan *Task
func RegisterResult(t *Task, good bool) { func ResultService() {
if good { openResultFile()
log("res", 0, "****************\n******** OK: %v %v %v\n****************", t.e.String(), t.login, t.password) defer closeResultFile()
resultChannel = make(chan *Task, 128)
for task := range resultChannel {
if outFile != nil { if outFile != nil {
fmt.Fprintf(outFile, "%v\t%v\t%v\n", t.e.String(), t.login, t.password) fmt.Fprintf(outFile, "%v\n", task.String())
}
}
} }
func RegisterResult(task *Task, good bool) {
if !good {
log("res", 1, "bad: %v", task.String())
} else { } else {
log("res", 1, "bad: %v %v %v", t.e.String(), t.login, t.password) log("res", 0, "good: %v", task.String())
if outFile != nil {
resultChannel <- task
}
}
} }
func EndResults() {
close(resultChannel)
} }
func OpenOutFile() { func openResultFile() {
fileName := CfgGetString("out-file") fileName := getParamString("out-file")
if fileName == "" { if fileName == "" {
log("out", 0, "WARNING: out-file is not specified, results will only be logged in console") log("out", 0, "WARNING: out-file is not specified, results will only be logged in console")
outFile = nil outFile = nil
@ -28,19 +53,17 @@ func OpenOutFile() {
outFile, err = os.OpenFile(fileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) outFile, err = os.OpenFile(fileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil { if err != nil {
fail("error opening output file \"%v\": %v", fileName, err.Error()) fail("error opening result file \"%v\": %v", fileName, err.Error())
} }
log("out", 2, "opened output file \"%v\"", fileName) log("out", 2, "opened result file \"%v\"", fileName)
} }
} }
func CloseOutFile() { func closeResultFile() {
if outFile != nil {
outFile.Close() outFile.Close()
outFile = nil outFile = nil
log("out", 2, "closed result file")
} }
func init() {
CfgRegister("out-file", "good.txt", "results will be saved in this file")
CfgRegisterAlias("o", "out-file")
} }

@ -2,12 +2,11 @@ package main
import ( import (
"errors" "errors"
"net"
) )
// TODO: multiple services... // TODO: multiple services...
func TryLogin(task *Task, conn net.Conn) (res bool, err error) { func TryLogin(task *Task, conn *Connection) (res bool, err error) {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
log("srv", 1, "fatal error (panic) in service handler: %v", r) log("srv", 1, "fatal error (panic) in service handler: %v", r)

@ -8,64 +8,114 @@ import (
"sync" "sync"
) )
type Source struct { // SrcIP, SrcLogin and SrcPassword represent different source types.
name string // name of this source var SrcIP Source = Source{name: "ip", plainParmName: "ip", filesParmName: "ip-file"}
var SrcLogin Source = Source{name: "login", plainParmName: "login", filesParmName: "login-file"}
var SrcPassword Source = Source{name: "password", plainParmName: "password",
filesParmName: "password-file", transform: func(item string) (res string, err error) {
if getParamSwitch("no-password-trim") {
return item, nil
} else {
return strings.TrimSpace(item), nil
}
}}
plain []string // sources from commandline func init() {
contents []string // sources from files registerParam("ip", []string{}, "IPs or subnets in CIDR notation")
registerParam("ip-file", []string{}, "paths to files with IPs or subnets in CIDR notation (one entry per line)")
registerParam("login", []string{}, "one or more logins")
registerParam("login-file", []string{}, "paths to files with logins (one entry per line)")
registerParam("password", []string{}, "one or more passwords")
registerParam("password-file", []string{}, "paths to files with passwords (one entry per line)")
files []*os.File // file pointers registerSwitch("add-empty-password", "insert an empty password to the password list")
fileNames []string // file names registerSwitch("no-password-trim", "preserve leading and trailing spaces in passwords")
plainParmName string // name of "plain" commandline parameter registerSwitch("logins-first", "increment logins before passwords")
filesParmName string // name of "files" commandline parameter registerSwitch("file-contents-first", "try to go through source files first, defer commandline args for later")
}
transform func(item string) (string, error) // optional transformation function // LoadSources loads contents for all sources.
func LoadSources() {
log("src", 1, "loading sources")
fetchMutex sync.Mutex // sync mutex var wg sync.WaitGroup
wg.Add(3)
go SrcIP.LoadSource(&wg)
go SrcLogin.LoadSource(&wg)
go SrcPassword.LoadSource(&wg)
wg.Wait()
SrcIP.ReportLoaded()
SrcLogin.ReportLoaded()
SrcPassword.ReportLoaded()
ParseEndpoints(SrcIP.plain)
ParseEndpoints(SrcIP.contents)
log("src", 1, "ok: finished loading sources")
} }
// both -1: exhausted // CloseSources closes all source files.
// both 0: not started yet func CloseSources() {
type SourcePos struct { log("src", 1, "closing sources")
plainIdx int
contentIdx int SrcIP.CloseSource()
SrcLogin.CloseSource()
SrcPassword.CloseSource()
log("src", 1, "ok: finished closing sources")
} }
// String converts a SourcePos to its string representation. // LoadSource fills a Source with data (from commandline and from files).
func (pos *SourcePos) String() string { func (src *Source) LoadSource(wg *sync.WaitGroup) {
return "P" + strconv.Itoa(pos.plainIdx) + "/C" + strconv.Itoa(pos.contentIdx) if wg != nil {
defer wg.Done()
} }
// Exhausted checks if a SourcePos can no longer produce any sources. if src.name == "password" && getParamSwitch("add-empty-password") {
func (pos *SourcePos) Exhausted() bool { src.plain = append(src.plain, "")
return pos.plainIdx == -1 && pos.contentIdx == -1
} }
// Reset moves a SourcePos to its starting position. src.ParsePlain()
func (pos *SourcePos) Reset() {
pos.plainIdx = 0 src.OpenFiles()
pos.contentIdx = 0 defer src.CloseSource()
log("src", 3, "resetting source pos")
src.ParseFiles()
failIf(len(src.contents)+len(src.plain) == 0, "no %vs defined: check %v and %v parameters", src, src.plainParmName, src.filesParmName)
} }
// passwordTransform is a transformation function for a password. // CloseSource closes all files for a Source.
func passwordTransform(item string) (res string, err error) { func (src *Source) CloseSource() {
if getParamSwitch("no-password-trim") { l := len(src.files)
return item, nil for _, file := range src.files {
} else { if file != nil {
return strings.TrimSpace(item), nil file.Close()
} }
} }
// SrcIP, SrcLogin and SrcPassword represent different sources. src.files = nil
var SrcIP Source = Source{name: "ip", plainParmName: "ip", filesParmName: "ip-file"} src.fileNames = nil
var SrcLogin Source = Source{name: "login", plainParmName: "login", filesParmName: "login-file"} log("src", 1, "closed all %v %v files", l, src)
var SrcPassword Source = Source{name: "password", plainParmName: "password", }
filesParmName: "password-file", transform: passwordTransform}
// String converts a Source to its string representation. // OpenFiles opens all files for a Source.
func (src *Source) String() string { func (src *Source) OpenFiles() {
return src.name fileNames := getParamStringSlice(src.filesParmName)
for _, fileName := range fileNames {
f, err := os.Open(fileName)
if err != nil {
fail("error opening source file \"%v\": %v", fileName, err.Error())
}
src.files = append(src.files, f)
src.fileNames = append(src.fileNames, fileName)
}
if len(src.files) > 0 {
log("src", 1, "opened %v %v files", len(src.files), src)
}
} }
// ValidateAndTransformItem attempts to validate a source item // ValidateAndTransformItem attempts to validate a source item
@ -84,51 +134,6 @@ func (src *Source) ValidateAndTransformItem(item string) (res string, err error)
} }
} }
// FetchFromSlice retrieves an item from a string slice and optionally increments its current position.
func (src *Source) FetchFromSlice(name string, idx *int, slice []string, inc bool) (res string, empty bool) {
if *idx == -1 { // exhausted
log("src", 5, "fetch %v from %v: idx is -1, return empty", src, name)
return "", true
}
if *idx >= len(slice) {
log("src", 5, "fetch %v from %v: idx >= slice length (%v >= %v), marking as exhausted, return empty", src, name, *idx, len(slice))
*idx = -1
return "", true
}
res = slice[*idx]
log("src", 5, "fetch %v from %v: ok, got %v at idx %v", src, name, res, *idx)
if inc {
*idx = *idx + 1
log("src", 5, "fetch %v from %v: incrementing idx to %v", src, name, *idx)
}
return res, false
}
// FetchOne retrieves an item from a Source with a specified SourcePos.
func (src *Source) FetchOne(pos *SourcePos, inc bool) (res string, empty bool) {
src.fetchMutex.Lock()
defer src.fetchMutex.Unlock()
if getParamSwitch("file-contents-first") {
res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc)
if empty {
res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc)
}
} else {
res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc)
if empty {
res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc)
}
}
logIf(empty, "src", 2, "exhausted source %v for pos %v", src, pos.String())
return res, empty
}
// ParsePlain parses commandline parameters for a Source. // ParsePlain parses commandline parameters for a Source.
func (src *Source) ParsePlain() { func (src *Source) ParsePlain() {
for _, plain := range getParamStringSlice(src.plainParmName) { for _, plain := range getParamStringSlice(src.plainParmName) {
@ -145,26 +150,7 @@ func (src *Source) ParsePlain() {
} }
} }
// OpenFiles opens all files for a Source. // ParseFiles parses files for a Source.
func (src *Source) OpenFiles() {
fileNames := getParamStringSlice(src.filesParmName)
for _, fileName := range fileNames {
f, err := os.Open(fileName)
if err != nil {
fail("error opening source file \"%v\": %v", fileName, err.Error())
}
src.files = append(src.files, f)
src.fileNames = append(src.fileNames, fileName)
}
if len(src.files) > 0 {
log("src", 1, "opened %v %v files", len(src.files), src)
}
}
// ParseFiles parses all files for a Source.
func (src *Source) ParseFiles() { func (src *Source) ParseFiles() {
for i, file := range src.files { for i, file := range src.files {
fileName := src.fileNames[i] fileName := src.fileNames[i]
@ -193,95 +179,101 @@ func (src *Source) ParseFiles() {
} }
} }
// FailIfEmpty throws an exception if a source is empty.
func (src *Source) FailIfEmpty() {
failIf(len(src.contents)+len(src.plain) == 0, "no %vs defined: check %v and %v parameters", src, src.plainParmName, src.filesParmName)
}
// ReportLoaded prints a console message about the number of loaded items for a Source. // ReportLoaded prints a console message about the number of loaded items for a Source.
func (src *Source) ReportLoaded() { func (src *Source) ReportLoaded() {
log("src", 0, "loaded %vs: %v items from commandline and %v items from files", src, len(src.plain), len(src.contents)) log("src", 0, "loaded %vs: %v items from commandline and %v items from files", src, len(src.plain), len(src.contents))
} }
// LoadSource fills a Source with data (from commandline and from files). type Source struct {
func (src *Source) LoadSource(wg *sync.WaitGroup) { name string // name of this source
if wg != nil {
defer wg.Done()
}
if src.name == "password" && getParamSwitch("add-empty-password") { plain []string // sources from commandline
src.plain = append(src.plain, "") contents []string // sources from files
}
src.ParsePlain() files []*os.File // file pointers
fileNames []string // file names
plainParmName string // name of "plain" commandline parameter
filesParmName string // name of "files" commandline parameter
src.OpenFiles() transform func(item string) (string, error) // optional transformation function
defer src.CloseSource()
src.ParseFiles() fetchMutex sync.Mutex // sync mutex
src.FailIfEmpty()
} }
// CloseSource closes all files for a Source. // String converts a Source to its string representation.
func (src *Source) CloseSource() { func (src *Source) String() string {
l := len(src.files) return src.name
for _, file := range src.files {
if file != nil {
file.Close()
} }
// FetchFromSlice retrieves an item from a string slice and optionally increments its current position.
func (src *Source) FetchFromSlice(name string, idx *int, slice []string, inc bool) (res string, empty bool) {
if *idx == -1 { // exhausted
log("src", 5, "fetch %v from %v: idx is -1, return empty", src, name)
return "", true
} }
src.files = nil if *idx >= len(slice) {
src.fileNames = nil log("src", 5, "fetch %v from %v: idx >= slice length (%v >= %v), marking as exhausted, return empty", src, name, *idx, len(slice))
log("src", 1, "closed all %v %v files", l, src) *idx = -1
return "", true
} }
// --- res = slice[*idx]
// --- log("src", 5, "fetch %v from %v: ok, got %v at idx %v", src, name, res, *idx)
// ---
// LoadSources loads contents for all sources, both from commandline and from files. if inc {
func LoadSources() { *idx = *idx + 1
log("src", 1, "loading sources") log("src", 5, "fetch %v from %v: incrementing idx to %v", src, name, *idx)
}
var wg sync.WaitGroup return res, false
wg.Add(3) }
go SrcIP.LoadSource(&wg)
go SrcLogin.LoadSource(&wg)
go SrcPassword.LoadSource(&wg)
wg.Wait()
SrcIP.ReportLoaded() // FetchOne retrieves an item from a Source with a specified SourcePos.
SrcLogin.ReportLoaded() func (src *Source) FetchOne(pos *SourcePos, inc bool) (res string, empty bool) {
SrcPassword.ReportLoaded() src.fetchMutex.Lock()
defer src.fetchMutex.Unlock()
ParseEndpoints(SrcIP.plain) if getParamSwitch("file-contents-first") {
ParseEndpoints(SrcIP.contents) res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc)
if empty {
res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc)
}
} else {
res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc)
if empty {
res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc)
}
}
log("src", 1, "ok: finished loading sources") logIf(empty, "src", 2, "exhausted source %v for pos %v", src, pos.String())
return res, empty
} }
// CloseSources closes all source files. // ---
func CloseSources() { // ---
log("src", 1, "closing sources") // ---
SrcIP.CloseSource() // both -1: exhausted
SrcLogin.CloseSource() // both 0: not started yet
SrcPassword.CloseSource() type SourcePos struct {
plainIdx int
contentIdx int
}
log("src", 1, "ok: finished closing sources") // String converts a SourcePos to its string representation.
func (pos *SourcePos) String() string {
return "P" + strconv.Itoa(pos.plainIdx) + "/C" + strconv.Itoa(pos.contentIdx)
} }
func init() { // Exhausted checks if a SourcePos can no longer produce any sources.
registerParam("ip", []string{}, "IPs or subnets in CIDR notation") func (pos *SourcePos) Exhausted() bool {
registerParam("ip-file", []string{}, "paths to files with IPs or subnets in CIDR notation (one entry per line)") return pos.plainIdx == -1 && pos.contentIdx == -1
registerParam("login", []string{}, "one or more logins") }
registerParam("login-file", []string{}, "paths to files with logins (one entry per line)")
registerParam("password", []string{}, "one or more passwords")
registerParam("password-file", []string{}, "paths to files with passwords (one entry per line)")
registerSwitch("add-empty-password", "insert an empty password to the password list") // Reset moves a SourcePos to its starting position.
registerSwitch("no-password-trim", "preserve leading and trailing spaces in passwords") func (pos *SourcePos) Reset() {
registerSwitch("logins-first", "increment logins before passwords") pos.plainIdx = 0
registerSwitch("file-contents-first", "try to go through source files first, defer commandline args for later") pos.contentIdx = 0
log("src", 3, "resetting source pos")
} }

@ -1,13 +1,22 @@
package main package main
import ( import (
rbt "github.com/emirpasic/gods/trees/redblacktree" "container/list"
rbtUtils "github.com/emirpasic/gods/utils"
"net"
"sync" "sync"
"time" "time"
) )
// deferredTasks is a list of tasks that were deferred for processing to a later time.
// This usually happens due to connection errors, protocol errors or per-endpoint limits.
var deferredTasks *list.List
// taskMutex is a mutex for safe handling of deferred task list.
var taskMutex sync.Mutex
func init() {
deferredTasks = list.New()
}
// TaskEvent represents all events that can be issued on a Task. // TaskEvent represents all events that can be issued on a Task.
type TaskEvent int type TaskEvent int
@ -52,26 +61,13 @@ func (ev TaskEvent) String() string {
// Every Task is linked to an Endpoint. // Every Task is linked to an Endpoint.
type Task struct { type Task struct {
e *Endpoint e *Endpoint
login, password string login string
password string
deferUntil time.Time deferUntil time.Time
numDeferrals int numDeferrals int
listElement *list.Element // position in list
// this should not be in Task struct!
conn net.Conn
// ??? do we even need this here
good bool
} }
// maxSafeThreads is a safeguard to prevent creation of too many threads at once.
const maxSafeThreads = 5000
// deferredTasks is a list of tasks that were deferred for processing to a later time.
// This usually happens due to connection errors, protocol errors or per-endpoint limits.
var deferredTasks *rbt.Tree
// taskMutex is a mutex for safe handling of RB tree.
var taskMutex sync.Mutex
// String returns a string representation of a Task. // String returns a string representation of a Task.
func (task *Task) String() string { func (task *Task) String() string {
if task == nil { if task == nil {
@ -86,12 +82,7 @@ func (task *Task) Defer(addTime time.Duration) {
task.deferUntil = time.Now().Add(addTime) task.deferUntil = time.Now().Add(addTime)
task.numDeferrals++ task.numDeferrals++
// tell the endpoint that we got deferred, maxDeferrals := getParamInt("task-max-deferrals")
// so it won't be selected until the deferral time has passed
// task.e.SetDeferralTime(task.deferUntil)
// FIXME: this isn't needed, endpoints can handle their own delays
maxDeferrals := CfgGetInt("task-max-deferrals")
if maxDeferrals != -1 && task.numDeferrals >= maxDeferrals { if maxDeferrals != -1 && task.numDeferrals >= maxDeferrals {
log("task", 5, "giving up on task \"%v\" because it has exhausted its deferral limit (%v)", task, maxDeferrals) log("task", 5, "giving up on task \"%v\" because it has exhausted its deferral limit (%v)", task, maxDeferrals)
return return
@ -102,7 +93,9 @@ func (task *Task) Defer(addTime time.Duration) {
taskMutex.Lock() taskMutex.Lock()
defer taskMutex.Unlock() defer taskMutex.Unlock()
deferredTasks.Put(task.deferUntil, task) if task.listElement == nil {
task.listElement = deferredTasks.PushBack(task)
}
} }
// EventWithParm tells a Task (and its underlying Endpoint) that // EventWithParm tells a Task (and its underlying Endpoint) that
@ -115,21 +108,21 @@ func (task *Task) EventWithParm(event TaskEvent, parm any) bool {
return true // do not process generic events return true // do not process generic events
} }
res := task.e.EventWithParm(event, parm) // notify the endpoint first endpointOk := task.e.EventWithParm(event, parm) // notify the endpoint first
switch event { switch event {
// on these events, defer a Task only if its Endpoint is being kept // on these events, defer a Task only if its Endpoint is being kept
case TE_NoResponse: case TE_NoResponse:
if res { if endpointOk {
task.Defer(CfgGetDurationMS("no-response-delay-ms")) task.Defer(getParamDurationMS("no-response-delay-ms"))
} }
case TE_ReadFailed: case TE_ReadFailed:
if res { if endpointOk {
task.Defer(CfgGetDurationMS("read-error-delay-ms")) task.Defer(getParamDurationMS("read-error-delay-ms"))
} }
case TE_ProtocolError: case TE_ProtocolError:
if res { if endpointOk {
task.Defer(CfgGetDurationMS("protocol-error-delay-ms")) task.Defer(getParamDurationMS("protocol-error-delay-ms"))
} }
// report about a bad/good auth result // report about a bad/good auth result
@ -144,7 +137,7 @@ func (task *Task) EventWithParm(event TaskEvent, parm any) bool {
time.Sleep(parm.(time.Duration)) time.Sleep(parm.(time.Duration))
} }
return res return endpointOk
} }
// Event is a parameterless version of EventWithParm. // Event is a parameterless version of EventWithParm.
@ -156,42 +149,26 @@ func (task *Task) Event(event TaskEvent) bool {
func GetDeferredTask() (task *Task, waitTime time.Duration) { func GetDeferredTask() (task *Task, waitTime time.Duration) {
currentTime := time.Now() currentTime := time.Now()
if deferredTasks.Empty() { if deferredTasks.Len() == 0 {
log("task", 5, "deferred task list is empty") log("task", 5, "deferred task list is empty")
return nil, 0 return nil, 0
} }
// check if a deferred task's endpoint is OK to fetch - minWaitTime := time.Time{}
// sometimes, a task is OK to fetch but the endpoint was delayed by something else
it := deferredTasks.IteratorAt(deferredTasks.Left())
for {
k, v := it.Key().(time.Time), it.Value().(*Task)
if k.After(currentTime) {
log("task", 5, "deferred tasks cannot yet be processed at this time")
return nil, k.Sub(currentTime)
}
if k.Before(v.deferUntil) { for e := deferredTasks.Front(); e != nil; e = e.Next() {
log("task", 5, "deferred task was re-deferred: removing its previous definition") dt := e.Value.(*Task)
defer deferredTasks.Remove(k) if minWaitTime.IsZero() || (dt.deferUntil.Before(minWaitTime) && dt.deferUntil.After(currentTime)) {
it.Next() minWaitTime = dt.deferUntil
continue
} }
if !v.e.delayUntil.IsZero() && v.e.delayUntil.After(currentTime) { if dt.deferUntil.Before(currentTime) && (dt.e.delayUntil.IsZero() || dt.e.delayUntil.Before(currentTime)) {
// skip this task: deferred task is OK, but its endpoint is delayed deferredTasks.Remove(e)
it.Next() return dt, 0
continue
} }
defer deferredTasks.Remove(k)
return v, 0
} }
log("task", 5, "deferred tasks are OK for processing but their endpoints cannot yet be processed at this time") return nil, minWaitTime.Sub(currentTime)
return nil, 0
} }
// FetchTaskComponents returns all components needed to build a Task. // FetchTaskComponents returns all components needed to build a Task.
@ -240,46 +217,25 @@ func CreateTask() (task *Task, delay time.Duration) {
task, delayDeferred := GetDeferredTask() task, delayDeferred := GetDeferredTask()
if task != nil { if task != nil {
log("task", 4, "new task (deferred): %v", task) log("task", 4, "new task (deferred): %v", task)
task.conn = nil
task.good = false
return task, 0 return task, 0
} }
ep, login, password, delaySource := FetchTaskComponents() ep, login, password, delayEndpoint := FetchTaskComponents()
if ep == nil { if ep == nil {
if delayDeferred == 0 && delaySource == 0 { if delayDeferred == 0 && delayEndpoint == 0 {
log("task", 4, "cannot build task, no endpoint") log("task", 4, "cannot build task, no endpoint")
return nil, 0 return nil, 0
} else if delayDeferred > delaySource || delayDeferred == 0 { } else if delayDeferred > delayEndpoint || delayDeferred == 0 {
log("task", 4, "delaying task creation (by source delay) for %v", delaySource) log("task", 4, "delaying task creation (by endpoint delay) for %v", delayEndpoint)
return nil, delaySource return nil, delayEndpoint
} else { } else {
log("task", 4, "delaying task creation (by deferred delay) for %v", delayDeferred) log("task", 4, "delaying task creation (by deferred delay) for %v", delayDeferred)
return nil, delayDeferred return nil, delayDeferred
} }
} }
t := Task{} t := Task{e: ep, login: login, password: password}
t.e = ep
t.login = login
t.password = password
t.good = false
log("task", 4, "new task: %v", &t) log("task", 4, "new task: %v", &t)
return &t, 0 return &t, 0
} }
func init() {
deferredTasks = rbt.NewWith(rbtUtils.TimeComparator)
CfgRegister("threads", 3, "how many threads to use")
CfgRegister("thread-delay-ms", 10, "separate threads at startup for this amount of ms")
CfgRegister("connect-timeout-ms", 3000, "")
CfgRegister("read-timeout-ms", 2000, "")
// using a very high limit for now, but this should actually be set to -1
CfgRegister("task-max-deferrals", 30000, "how many deferrals are allowed for a single task. -1 to disable")
CfgRegisterAlias("t", "threads")
}

@ -1,112 +1,90 @@
package main package main
// thread.go: handling of worker threads
import ( import (
"net"
"sync" "sync"
"time" "time"
) )
// threadWork processes a single work item for a thread. // maxSafeThreads is a safeguard to prevent creation of too many threads at once.
func threadWork(dialer *net.Dialer) bool { const maxSafeThreads = 5000
readTimeout := CfgGetDurationMS("read-timeout-ms")
// ThreadService creates, starts up and waits for all threads.
func ThreadService() {
numThreads := getParamInt("threads")
failIf(numThreads > maxSafeThreads, "too many threads (max %v)", maxSafeThreads)
log("thread", 0, "initializing %v threads", numThreads)
c := make(chan bool)
var wg sync.WaitGroup
for i := 1; i <= numThreads; i++ {
wg.Add(1)
go threadEntryPoint(c, i, &wg)
}
threadDelay := getParamDurationMS("thread-delay-ms")
log("thread", 0, "starting %v threads", numThreads)
for i := 1; i <= numThreads; i++ {
c <- true
if threadDelay > 0 {
time.Sleep(threadDelay)
}
}
log("thread", 1, "waiting for threads")
wg.Wait()
log("thread", 1, "finished waiting for threads")
}
// threadEntryPoint is the main entrypoint for a work thread.
func threadEntryPoint(c chan bool, threadIdx int, wg *sync.WaitGroup) {
<-c
log("thread", 3, "starting loop for thread %v", threadIdx)
for threadWork() {
}
log("thread", 3, "exiting thread %v", threadIdx)
wg.Done()
}
// threadWork processes a single work item for a thread.
func threadWork() bool {
task, delay := CreateTask() task, delay := CreateTask()
if task == nil { if task == nil {
if delay > 0 { if delay > 0 {
log("thread", 3, "no endpoints available, sleeping for %v", delay) log("thread", 3, "no active endpoints available, sleeping for %v", delay)
time.Sleep(delay) time.Sleep(delay)
return true return true
} else { } else {
log("thread", 3, "no endpoints available, stopping thread loop") log("thread", 3, "no endpoints available (active and deferred), stopping thread loop")
return false return false
} }
} }
conn, err := dialer.Dial("tcp", task.e.String()) conn, err := NewConnection(task.e)
if err != nil { if err != nil {
task.Event(TE_NoResponse) task.EventWithParm(TE_NoResponse, err)
log("thread", 2, "cannot connect to \"%v\": %v", task.e, err.Error())
return true return true
} }
defer conn.Close()
task.conn = conn
task.Event(TN_Connected) task.Event(TN_Connected)
conn.SetReadDeadline(time.Now().Add(readTimeout)) // should be just before Send() call... log("thread", 2, "trying %v", task)
log("thread", 2, "trying %v:%v on \"%v\"", task.login, task.password, task.e)
// TODO: multiple services (currently just WinBox)
res, err := TryLogin(task, conn) res, err := TryLogin(task, conn)
if err != nil { if err != nil {
task.Event(TE_ProtocolError) task.Event(TE_ProtocolError)
} else { } else {
if res && err == nil { if res && err == nil {
task.EventWithParm(TE_Good, task.login) task.Event(TE_Good)
} else { } else {
task.EventWithParm(TE_Bad, task.login) task.Event(TE_Bad)
} }
} }
return true return true
} }
// threadLoop calls threadWork in a loop, until the endpoints are exhausted,
// a pause/stop signal has been raised, or an exception has occurred in threadWork.
func threadLoop(dialer *net.Dialer) {
for threadWork(dialer) {
// TODO: pause/stop signal
// TODO: exception handling
}
}
// threadEntryPoint is the main entrypoint for a work thread.
func threadEntryPoint(c chan bool, threadIdx int, wg *sync.WaitGroup) {
<-c
log("thread", 3, "starting loop for thread %v", threadIdx)
connectTimeout := time.Duration(CfgGetInt("connect-timeout-ms")) * time.Millisecond
dialer := net.Dialer{Timeout: connectTimeout, KeepAlive: -1}
threadLoop(&dialer)
log("thread", 3, "exiting thread %v", threadIdx)
wg.Done()
}
// InitializeThreads creates and starts up all threads.
func InitializeThreads() *sync.WaitGroup {
numThreads := CfgGetInt("threads")
failIf(numThreads > maxSafeThreads, "too many threads (max %v)", maxSafeThreads)
log("thread", 0, "initializing %v threads", numThreads)
c := make(chan bool)
var wg sync.WaitGroup
for i := 1; i <= numThreads; i++ {
wg.Add(1)
go threadEntryPoint(c, i, &wg)
}
threadDelay := CfgGetDurationMS("thread-delay-ms")
log("thread", 0, "starting %v threads", numThreads)
for i := 1; i <= numThreads; i++ {
c <- true
if threadDelay > 0 {
time.Sleep(threadDelay)
}
}
log("thread", 0, "started")
return &wg
}
// WaitForThreads enters a wait state and keeps it until
// all threads have exited.
func WaitForThreads(wg *sync.WaitGroup) {
log("thread", 1, "waiting for threads")
wg.Wait()
log("thread", 1, "finished waiting for threads")
}

@ -3,21 +3,21 @@ package main
import ( import (
"bytes" "bytes"
"errors" "errors"
"net"
) )
type Winbox struct { type Winbox struct {
task *Task task *Task
conn net.Conn conn *Connection
stage int stage int
user, pass string user, pass string
w *WCurve w *WCurve
sa, xwa, xwb, j, z, secret, clientCC, serverCC, i, msg, resp []byte sa, xwa, xwb, j, z, secret, i []byte
clientCC, serverCC, msg, resp []byte
xwaParity, xwbParity bool xwaParity, xwbParity bool
} }
func NewWinbox(task *Task, conn net.Conn) *Winbox { func NewWinbox(task *Task, conn *Connection) *Winbox {
winbox := Winbox{} winbox := Winbox{}
winbox.task = task winbox.task = task
winbox.conn = conn winbox.conn = conn
@ -138,7 +138,7 @@ func (winbox *Winbox) confirmation() error {
func (winbox *Winbox) sendAndRecv() error { func (winbox *Winbox) sendAndRecv() error {
if len(winbox.msg) > 0 && winbox.conn != nil { if len(winbox.msg) > 0 && winbox.conn != nil {
_, err := winbox.conn.Write(winbox.msg) err := winbox.conn.Send(winbox.msg)
winbox.msg = []byte{} winbox.msg = []byte{}
if err != nil { if err != nil {
@ -146,14 +146,11 @@ func (winbox *Winbox) sendAndRecv() error {
return err return err
} }
winbox.resp = make([]byte, 1024) winbox.resp, err = winbox.conn.Recv()
n, err := winbox.conn.Read(winbox.resp)
if err != nil { if err != nil {
log("winbox", 1, "failed to recv: %v", err.Error()) log("winbox", 1, "failed to recv: %v", err.Error())
return err return err
} }
winbox.resp = winbox.resp[:n]
} }
return nil return nil

Loading…
Cancel
Save