From 68fee6bbb7e21d0497b3899f5116858f82cb66f1 Mon Sep 17 00:00:00 2001 From: dave Date: Mon, 28 Nov 2022 22:17:19 +0300 Subject: [PATCH] feat: partial refactoring --- config.go | 163 ++++++++++++++++--------------- conn.go | 5 +- crypt.go | 4 +- curve.go | 2 +- endpoint.go | 253 +++++++++++++++++++++++++++-------------------- go.mod | 4 +- go.sum | 2 - source.go | 275 ++++++++++++++++++++++++++-------------------------- winbox.go | 2 +- 9 files changed, 378 insertions(+), 332 deletions(-) delete mode 100644 go.sum diff --git a/config.go b/config.go index 0f3d20e..b2e7bd6 100644 --- a/config.go +++ b/config.go @@ -8,29 +8,34 @@ import ( "time" ) -// ConfigParameterOptions represents additional options for a ConfigParameter. -type ConfigParameterOptions struct { +// configParameterOptions represents additional options for a configParameter. +type configParameterOptions struct { sw, hidden, command bool callback func() } -// ConfigParameter represents a single configuration parameter. -type ConfigParameter struct { - name string // duplicated in configMap, but also saved here for convenience - value, def interface{} // value and default value - description string // description for this parameter - parsed bool // true if it was successfully parsed from commandline - opts ConfigParameterOptions +// configParameter represents a single configuration parameter. +type configParameter struct { + name string // duplicated in configMap, but also saved here for convenience + value, def any // value and default value + description string // description for this parameter + parsed bool // true if it was successfully parsed from commandline + configParameterOptions } -var configMap map[string]ConfigParameter = make(map[string]ConfigParameter, 16) -var configAliasMap map[string]string = make(map[string]string, 4) +type configParameterTypeUnion = interface { + bool | int | uint | float64 | string | []int | []uint | []float64 | []string | []bool +} + +var configMap = map[string]configParameter{} +var configAliasMap = map[string]string{} -var ConfigParsingFinished bool = false +var configParsingFinished = false +// --- // registration -func genericRegister(name string, def interface{}, description string, opts ConfigParameterOptions) { +func registerConfigParameter[T configParameterTypeUnion](name string, def T, description string, opts configParameterOptions) { name = strings.ToLower(name) _, ok := configMap[name] @@ -41,41 +46,35 @@ func genericRegister(name string, def interface{}, description string, opts Conf failIf(opts.command && opts.callback == nil, "\"%v\" is defined as a command but callback is missing", name) - p := ConfigParameter{} - p.name = name - p.value = def - p.def = def - p.description = description - p.opts = opts - - configMap[name] = p + configMap[name] = configParameter{name: name, value: def, def: def, description: description, + configParameterOptions: opts} } -func CfgRegister(name string, def interface{}, description string) { - genericRegister(name, def, description, ConfigParameterOptions{}) +func registerParam[T configParameterTypeUnion](name string, def T, description string) { + registerConfigParameter(name, def, description, configParameterOptions{}) } -func CfgRegisterEx(name string, def interface{}, description string, options ConfigParameterOptions) { - genericRegister(name, def, description, options) +func registerParamEx[T configParameterTypeUnion](name string, def T, description string, options configParameterOptions) { + registerConfigParameter(name, def, description, options) } -func CfgRegisterCallback(name string, def interface{}, description string, callback func()) { - genericRegister(name, def, description, ConfigParameterOptions{callback: callback}) +func registerParamHidden[T configParameterTypeUnion](name string, def T) { + registerConfigParameter(name, def, "", configParameterOptions{hidden: true}) } -func CfgRegisterCommand(name string, description string, callback func()) { - genericRegister(name, false, description, ConfigParameterOptions{command: true, callback: callback}) +func registerParamWithCallback[T configParameterTypeUnion](name string, def T, description string, callback func()) { + registerConfigParameter(name, def, description, configParameterOptions{callback: callback}) } -func CfgRegisterSwitch(name string, description string) { - genericRegister(name, false, description, ConfigParameterOptions{sw: true}) +func registerCommand(name string, description string, callback func()) { + registerConfigParameter(name, false, description, configParameterOptions{command: true, callback: callback}) } -func CfgRegisterHidden(name string, def interface{}) { - genericRegister(name, def, "", ConfigParameterOptions{hidden: true}) +func registerSwitch(name string, description string) { + registerConfigParameter(name, false, description, configParameterOptions{sw: true}) } -func CfgRegisterAlias(alias, target string) { +func registerAlias(alias, target string) { alias, target = strings.ToLower(alias), strings.ToLower(target) _, ok := configAliasMap[alias] @@ -90,52 +89,59 @@ func CfgRegisterAlias(alias, target string) { configAliasMap[alias] = target } +// --- // acquisition -func CfgGet(name string) interface{} { +func getParamGeneric(name string) any { name = strings.ToLower(name) + parm, ok := configMap[name] failIf(!ok, "unknown config parameter: \"%v\"", name) + failIf(parm.command, "config parameter \"%v\" is a command", name) - if parm.opts.sw { + if parm.sw { return parm.parsed // switches always return true if they were parsed - } else if parm.parsed || parm.opts.hidden { + } else if parm.parsed || parm.hidden { return parm.value // parsed and hidden parms return their current value } else { return parm.def // otherwise, use default value } } -func CfgGetInt(name string) int { - return CfgGet(name).(int) +func getParam[T configParameterTypeUnion](name string) T { + return getParamGeneric(name).(T) } -func CfgGetFloat(name string) float64 { - return CfgGet(name).(float64) +func getParamInt(name string) int { + return getParam[int](name) } -func CfgGetIntSlice(name string) []int { - return CfgGet(name).([]int) +func getParamFloat(name string) float64 { + return getParam[float64](name) } -func CfgGetBool(name string) bool { - return CfgGet(name).(bool) +func getParamIntSlice(name string) []int { + return getParam[[]int](name) } -func CfgGetSwitch(name string) bool { - return CfgGet(name).(bool) +func getParamBool(name string) bool { + return getParam[bool](name) } -func CfgGetString(name string) string { - return CfgGet(name).(string) +func getParamSwitch(name string) bool { + return getParamBool(name) } -func CfgGetStringSlice(name string) []string { - return CfgGet(name).([]string) +func getParamString(name string) string { + return getParam[string](name) } -func CfgGetDurationMS(name string) time.Duration { - tm := CfgGet(name).(int) +func getParamStringSlice(name string) []string { + return getParam[[]string](name) +} + +func getParamDurationMS(name string) time.Duration { + tm := getParam[int](name) failIf(tm < -1, "\"%v\" can only be set to -1 or a positive value", name) if tm == -1 { tm = 0 @@ -144,31 +150,36 @@ func CfgGetDurationMS(name string) time.Duration { return time.Duration(tm) * time.Millisecond } +// --- // setting -func CfgSet(name string, value interface{}) { +func setParam(name string, value any) { name = strings.ToLower(name) + parm, ok := configMap[name] failIf(!ok, "unknown config parameter: \"%v\"", name) - failIf(!parm.opts.hidden, "tried to set \"%v\", but it is not a hidden parameter", name) + failIf(parm.hidden, "config parameter \"%v\" is hidden and cannot be set", name) + failIf(parm.command, "config parameter \"%v\" is a command and cannot be set", name) + failIf(parm.sw && !value.(bool), "config parameter \"%v\" is a switch and only accepts boolean arguments", name) parm.value = value - if parm.opts.callback != nil { - parm.opts.callback() + if parm.callback != nil { + parm.callback() } configMap[name] = parm } +// --- // parsing -// getCmdlineParm returns a trimmed commandline parameter with specified index. +// getCmdlineParm retrieves a commandline parameter with index i. func getCmdlineParm(i int) string { return strings.TrimSpace(os.Args[i]) } -// isSlice checks if a ConfigParameter value is a slice. -func (parm *ConfigParameter) isSlice() bool { +// isSlice checks if a configParameter value is a slice. +func (parm *configParameter) isSlice() bool { switch parm.value.(type) { case []int, []uint, []string: return true @@ -177,8 +188,8 @@ func (parm *ConfigParameter) isSlice() bool { } } -// writeParmValue saves raw commandline value into a ConfigParameter. -func (parm *ConfigParameter) writeParmValue(value string) { +// writeParmValue saves raw commandline value into a configParameter. +func (parm *configParameter) writeParmValue(value string) { var err error switch parm.value.(type) { @@ -214,14 +225,14 @@ func (parm *ConfigParameter) writeParmValue(value string) { } } -// finalizeParm marks a ConfigParameter as parsed, adds it to a global config map +// finalize marks a configParameter as parsed, adds it to a global config map // and calls its callback, if one is present. -func (parm *ConfigParameter) finalizeParm() { +func (parm *configParameter) finalize() { parm.parsed = true configMap[parm.name] = *parm - if parm.opts.callback != nil { - parm.opts.callback() + if parm.callback != nil { + parm.callback() } log("cfg", 2, "parse: %T \"%v\" -> def %v, now %v", parm.value, parm.name, parm.def, parm.value) @@ -255,11 +266,11 @@ func parseAppConfig() { log("cfg", 3, "\"%v\" is aliased to \"%v\"", alias, parm.name) } - failIf(parm.opts.hidden, "\"%v\" is not a commandline parameter", getCmdlineParm(i)) + failIf(parm.hidden, "\"%v\" is not a commandline parameter", getCmdlineParm(i)) failIf(parm.parsed && !parm.isSlice(), "multiple occurrences of commandline parameter \"%v\" are not allowed", parm.name) - if !parm.opts.command { - if parm.opts.sw { + if !parm.command { + if parm.sw { parm.writeParmValue("true") } else { i++ @@ -267,12 +278,12 @@ func parseAppConfig() { } } - finalizeParm(&parm) + parm.finalize() totalFinalized++ } log("cfg", 1, "parsed %v commandline parameters", totalFinalized) - ConfigParsingFinished = true + configParsingFinished = true } func showHelp() { @@ -287,7 +298,7 @@ func showHelp() { for _, parmName := range parms { parm := configMap[parmName] - if parm.opts.hidden { + if parm.hidden { continue } @@ -319,7 +330,7 @@ func showHelp() { description = " " + parm.description } - if parm.opts.command || parm.opts.sw { + if parm.command || parm.sw { log("", 0, "%s\n%s", header, description) } else { log("", 0, "%s\n%s\n default: %v", header, description, parm.value) @@ -337,7 +348,7 @@ func showHelp() { } func init() { - CfgRegisterCommand("help", "show program usage", showHelp) - CfgRegisterAlias("?", "help") - CfgRegisterAlias("h", "help") + registerCommand("help", "show program usage", showHelp) + registerAlias("?", "help") + registerAlias("h", "help") } diff --git a/conn.go b/conn.go index c7d9352..e3ea937 100644 --- a/conn.go +++ b/conn.go @@ -17,8 +17,8 @@ type Connection struct { // NewConnection creates a Connection object. func NewConnection() *Connection { conn := Connection{} - conn.connectTimeout = CfgGetDurationMS("connect-timeout-ms") - conn.readTimeout = CfgGetDurationMS("read-timeout-ms") + conn.connectTimeout = getParamDurationMS("connect-timeout-ms") + conn.readTimeout = getParamDurationMS("read-timeout-ms") conn.protocol = "tcp" return &conn } @@ -47,5 +47,4 @@ func (conn *Connection) SetReadTimeout(timeout time.Duration) { // Send writes data to a Connection. func (conn *Connection) Send(data []byte) { conn.socket.SetReadDeadline(time.Now().Add(conn.readTimeout)) - } diff --git a/crypt.go b/crypt.go index 1f95d7b..a3fe906 100644 --- a/crypt.go +++ b/crypt.go @@ -84,7 +84,7 @@ func genRandomBytes(n int) ([]byte, error) { b := make([]byte, n) var err error - if CfgGetSwitch("crypt-predictable-rng") { + if getParamSwitch("crypt-predictable-rng") { _, err = mathRand.Read(b) } else { _, err = cryptoRand.Read(b) @@ -98,6 +98,6 @@ func genRandomBytes(n int) ([]byte, error) { } func init() { - CfgRegisterSwitch("crypt-predictable-rng", "disable secure rng and use a pseudorandom preseeded rng") + registerSwitch("crypt-predictable-rng", "disable secure rng and use a pseudorandom preseeded rng") mathRand.Seed(300) } diff --git a/curve.go b/curve.go index ecac5b6..732971f 100644 --- a/curve.go +++ b/curve.go @@ -7,7 +7,7 @@ type WCurve struct { conversion bigint } -func NewWCurve() *WCurve { +func newWCurve() *WCurve { curve := WCurve{} curve.p = NewBigintFromString("7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed", 16) diff --git a/endpoint.go b/endpoint.go index 9a16963..81f6250 100644 --- a/endpoint.go +++ b/endpoint.go @@ -16,29 +16,30 @@ type Address struct { v6 bool } -type Endpoint struct { - addr Address +type EndpointState int - loginPos SourcePos - passwordPos SourcePos +const ( + ES_Normal EndpointState = iota + ES_Delayed + ES_Deleted +) - delayUntil time.Time +// An Endpoint represents a remote target and stores its persistent data between multiple connections. +type Endpoint struct { + addr Address // IP address of an endpoint - normalList *list.Element - delayedList *list.Element + loginPos, passwordPos SourcePos // login/password cursors + listElement *list.Element // position in list - goodConn int - badConn int - consecutiveGoodConn int - consecutiveBadConn int - protoErrors int - consecutiveProtoErrors int - readErrors int - consecutiveReadErrors int + state EndpointState // which state an endpoint is in + delayUntil time.Time // when this endpoint can be used again - mutex sync.Mutex + // endpoint stats + goodConn, badConn, protoErrors, readErrors int + consecutiveGoodConn, consecutiveBadConn, consecutiveProtoErrors, + consecutiveReadErrors int - deleted bool // set to TRUE to mark this endpoint as deleted + mutex sync.Mutex // sync primitive // unused, for now rtt float32 @@ -48,12 +49,35 @@ type Endpoint struct { lastAttemptAt time.Time // same, but for attempts } -var endpoints *list.List // Contains all active endpoints -var delayedEndpoints *list.List // Contains endpoints that got delayed +var endpoints *list.List // Contains all active and ready endpoints +var delayedEndpoints *list.List // Contains endpoints that are active, but not ready // A mutex for synchronizing Endpoint collections. var globalEndpointMutex sync.Mutex +func (state EndpointState) String() string { + switch state { + case ES_Normal: + return "normal" + case ES_Delayed: + return "delayed" + case ES_Deleted: + return "deleted" + } + + return "unknown" +} + +func (state EndpointState) GetList() *list.List { + switch state { + case ES_Normal: + return endpoints + case ES_Delayed: + return delayedEndpoints + } + + return nil +} // String transforms an Endpoint to a string representation compatible with Dialer interface. func (e *Endpoint) String() string { @@ -64,6 +88,9 @@ func (e *Endpoint) String() string { } } +func (e *Endpoint) GetList() *list.List { + return e.state.GetList() +} // Delete deletes an endpoint from global storage. // This method assumes that Endpoint's mutex was already taken. @@ -71,51 +98,55 @@ func (e *Endpoint) Delete() { globalEndpointMutex.Lock() defer globalEndpointMutex.Unlock() - e.delayUntil = time.Time{} - - if e.delayedList != nil { - log("ep", 3, "deleting delayed endpoint \"%v\"", e) - delayedEndpoints.Remove(e.delayedList) - e.delayedList = nil - } - - if e.normalList != nil { + list := e.GetList() + if list != nil { log("ep", 3, "deleting endpoint \"%v\"", e) - endpoints.Remove(e.normalList) - e.normalList = nil + list.Remove(e.listElement) + e.listElement = nil } - e.deleted = true + e.delayUntil = time.Time{} + e.state = ES_Deleted } -// Delay marks an Endpoint as "delayed" with the specified time duration -// and causes it to move to the delayed queue. -// This method assumes that Endpoint's mutex was already taken. -func (e *Endpoint) Delay(addTime time.Duration) { - if e.deleted { +// SetState changes an endpoint's state. +func (e *Endpoint) SetState(newState EndpointState) { + if e.state == newState { + log("ep", 5, "ignoring state change for an endpoint \"%v\": already in state \"%v\"", e, e.state) return } - log("ep", 5, "delaying endpoint \"%v\" for %v", e, addTime) - e.delayUntil = time.Now().Add(addTime) - e.MigrateToDelayed() -} + oldList := e.GetList() + newList := newState.GetList() -// MigrateToDelayed moves an Endpoint to a delayed queue. -// Endpoint mutex is assumed to be taken. -func (e *Endpoint) MigrateToDelayed() { - endpointMutex.Lock() - defer endpointMutex.Unlock() + globalEndpointMutex.Lock() + defer globalEndpointMutex.Unlock() - if e.delayedList != nil { - // already in a delayed list - log("ep", 5, "cannot migrate endpoint \"%v\" to delayed list: already in the list", e) + if e.listElement != nil { + oldList.Remove(e.listElement) + } + + if newList == nil { + e.listElement = nil } else { - log("ep", 5, "migrating endpoint \"%v\" to delayed list", e) - e.delayedList = delayedEndpoints.PushBack(e) - if e.normalList != nil { - endpoints.Remove(e.normalList) - e.normalList = nil + e.listElement = newList.PushBack(e) + } +} + +// Delay marks an Endpoint as "delayed" for a certain duration +// and migrates it to the delayed queue. +// This method assumes that Endpoint's mutex was already taken. +func (e *Endpoint) Delay(addTime time.Duration) { + if e.state == ES_Normal { + log("ep", 5, "delaying endpoint \"%v\" for %v", e, addTime) + e.delayUntil = time.Now().Add(addTime) + e.SetState(ES_Delayed) + } else if e.state == ES_Delayed { + // endpoints that are already delayed can have their delay time extended further + + tm := time.Now().Add(addTime) + if e.delayUntil.Before(tm) { + e.delayUntil = tm } } } @@ -170,30 +201,30 @@ func (e *Endpoint) NoResponse() bool { } // 1. always bail after X consecutive bad conns - if e.consecutiveBadConn >= CfgGetInt("max-bad-conn") { + if e.consecutiveBadConn >= getParamInt("max-bad-conn") { log("ep", 3, "deleting \"%v\" due to max-bad-conn", e) e.Delete() return false } // 2. after a good conn, always allow at most X bad conns - if e.goodConn > 0 && e.consecutiveBadConn <= CfgGetInt("max-bad-after-good-conn") { + if e.goodConn > 0 && e.consecutiveBadConn <= getParamInt("max-bad-after-good-conn") { log("ep", 3, "keeping \"%v\" around due to max-bad-after-good-conn", e) - e.Delay(CfgGetDurationMS("no-response-delay-ms")) + e.Delay(getParamDurationMS("no-response-delay-ms")) return true } // 3. always allow at most X bad conns - if e.consecutiveBadConn < CfgGetInt("min-bad-conn") { + if e.consecutiveBadConn < getParamInt("min-bad-conn") { log("ep", 3, "keeping \"%v\" around due to min-bad-conn", e) - e.Delay(CfgGetDurationMS("no-response-delay-ms")) + e.Delay(getParamDurationMS("no-response-delay-ms")) return true } // 4. bad conn/good conn ratio must not be higher than X - if e.goodConn > 0 && (float64(e.badConn)/float64(e.goodConn)) <= CfgGetFloat("conn-ratio") { + if e.goodConn > 0 && (float64(e.badConn)/float64(e.goodConn)) <= getParamFloat("conn-ratio") { log("ep", 3, "keeping \"%v\" around due to conn-ratio", e) - e.Delay(CfgGetDurationMS("no-response-delay-ms")) + e.Delay(getParamDurationMS("no-response-delay-ms")) return true } @@ -213,23 +244,23 @@ func (e *Endpoint) ProtocolError() bool { e.consecutiveProtoErrors++ // 1. always bail after X consecutive protocol errors - if e.consecutiveProtoErrors >= CfgGetInt("max-proto-errors") { + if e.consecutiveProtoErrors >= getParamInt("max-proto-errors") { log("ep", 3, "deleting \"%v\" due to max-proto-errors", e) e.Delete() return false } // 2. always allow at most X consecutive protocol errors - if e.consecutiveProtoErrors < CfgGetInt("min-proto-errors") { + if e.consecutiveProtoErrors < getParamInt("min-proto-errors") { log("ep", 3, "keeping \"%v\" around due to min-proto-errors", e) - e.Delay(CfgGetDurationMS("protocol-error-delay-ms")) + e.Delay(getParamDurationMS("protocol-error-delay-ms")) return true } // 3. bad conn/good conn ratio must not be higher than X - if e.goodConn > 0 && (float64(e.protoErrors)/float64(e.goodConn)) <= CfgGetFloat("proto-error-ratio") { + if e.goodConn > 0 && (float64(e.protoErrors)/float64(e.goodConn)) <= getParamFloat("proto-error-ratio") { log("ep", 3, "keeping \"%v\" around due to proto-error-ratio", e) - e.Delay(CfgGetDurationMS("protocol-error-delay-ms")) + e.Delay(getParamDurationMS("protocol-error-delay-ms")) return true } @@ -257,7 +288,7 @@ func (e *Endpoint) Good(login) { defer e.mutex.Unlock() e.consecutiveProtoErrors = 0 - if !CfgGetSwitch("keep-endpoint-on-good") { + if !getParamSwitch("keep-endpoint-on-good") { e.Delete() } else { e.MigrateToNormal() @@ -337,6 +368,7 @@ func (e *Endpoint) Exhausted() { } // GetDelayedEndpoint retrieves an Endpoint from the delayed list. +// globalEndpointMutex must be already taken. func GetDelayedEndpoint() (e *Endpoint, waitTime time.Duration) { currentTime := time.Now() @@ -350,7 +382,7 @@ func GetDelayedEndpoint() (e *Endpoint, waitTime time.Duration) { k, v := it.Key().(time.Time), it.Value().(*Endpoint) if v == nil { - log("ep", 5, "!!! empty delayed endpoint!!!") + panic("delayed endpoint list contains an empty endpoint") return nil, 0 } @@ -381,13 +413,15 @@ func GetDelayedEndpoint() (e *Endpoint, waitTime time.Duration) { return nil, 0 } -// FetchEndpoint retrieves an endpoint: first, a delayed RB tree is queried, +// FetchEndpoint retrieves an endpoint: first, a delayed list is queried, // then, if nothing is found, a normal list is searched, // and (TODO) if this list is empty or will soon be emptied, // a new batch of endpoints gets created. func FetchEndpoint() (e *Endpoint, waitTime time.Duration) { - endpointMutex.Lock() - defer endpointMutex.Unlock() + globalEndpointMutex.Lock() + defer globalEndpointMutex.Unlock() + + log("ep", 4, "fetching an endpoint") e, waitTime = GetDelayedEndpoint() if e != nil { @@ -413,6 +447,10 @@ func FetchEndpoint() (e *Endpoint, waitTime time.Duration) { return e, 0 } +// --- +// --- +// --- + // Safety feature, to avoid expanding subnets into a huge amount of IPs. const maxNetmaskSize = 22 // expands into /10 for IPv4 @@ -424,7 +462,7 @@ func RegisterEndpoint(ip string, ports []int, isIPv6 bool) int { ep.passwordPos.Reset() ep.listElement = endpoints.PushBack(&ep) - log("ep", 3, "ok registered: %v", &ep) + log("ep", 3, "registered endpoint: %v", &ep) } return len(ports) @@ -439,6 +477,7 @@ func incIP(ip net.IP) { } } +// parseCIDR registers multiple endpoints from a CIDR netmask. func parseCIDR(ip string, ports []int, isIPv6 bool) int { na, nm, err := net.ParseCIDR(ip) if err != nil { @@ -455,7 +494,7 @@ func parseCIDR(ip string, ports []int, isIPv6 bool) int { curHost := 0 maxHost := 1<<(maskBits-mask) - 1 numParsed := 0 - strict := CfgGetSwitch("strict-subnets") + strict := getParamSwitch("strict-subnets") log("ep", 2, "expanding CIDR: \"%v\" to %v hosts", ip, maxHost+1) @@ -463,7 +502,7 @@ func parseCIDR(ip string, ports []int, isIPv6 bool) int { if strict && (curHost == 0 || curHost == maxHost) && maskBits-mask >= 2 { log("ep", 1, "ignoring network/broadcast address due to strict-subnets: \"%v\"", expIP.String()) } else { - numParsed += parseIPPorts(expIP.String(), ports, isIPv6) + numParsed += RegisterEndpoint(expIP.String(), ports, isIPv6) } curHost++ } @@ -471,7 +510,8 @@ func parseCIDR(ip string, ports []int, isIPv6 bool) int { return numParsed } -func parseIPPorts(ip string, ports []int, isIPv6 bool) int { +// parseIPOrCIDR expands plain IP or CIDR to multiple endpoints. +func parseIPOrCIDR(ip string, ports []int, isIPv6 bool) int { // ip may be a domain name, a CIDR subnet or an IP address // CIDR subnets must be expanded to plain IPs @@ -485,7 +525,8 @@ func parseIPPorts(ip string, ports []int, isIPv6 bool) int { } } -func extractIPAndPort(str string, skippedIPv6 *int) (ip string, port int, err error) { +// extractIPAndPort extracts all endpoint components. +func extractIPAndPort(str string) (ip string, port int, err error) { var portString string ip, portString, err = net.SplitHostPort(str) if err != nil { @@ -504,8 +545,10 @@ func extractIPAndPort(str string, skippedIPv6 *int) (ip string, port int, err er return ip, port, nil } +// ParseEndpoints takes a string slice of IPs/CIDR subnets and converts it to a list of endpoints. func ParseEndpoints(source []string) { log("ep", 1, "parsing endpoints") + totalIPv6Skipped := 0 numParsed := 0 @@ -514,74 +557,72 @@ func ParseEndpoints(source []string) { // no ":": this is an ipv4/dn without port, // parse it with all known ports - numParsed += parseIPPorts(str, CfgGetIntSlice("port"), false) + numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), false) } else { // either ipv4/dn with port, or ipv6 with/without port isIPv6 := strings.Count(str, ":") > 1 - if isIPv6 && CfgGetSwitch("no-ipv6") { - log("ep", 1, "skipping ipv6 target \"%v\" due to no-ipv6", str) + if isIPv6 && getParamSwitch("no-ipv6") { totalIPv6Skipped++ continue } if !strings.Contains(str, "]:") && strings.Contains(str, "::") { // ipv6 without port - numParsed += parseIPPorts(str, CfgGetIntSlice("port"), true) + numParsed += parseIPOrCIDR(str, getParamIntSlice("port"), true) continue } - ip, port, err := extractIPAndPort(str, &totalIPv6Skipped) + ip, port, err := extractIPAndPort(str) if err != nil { - log("ep", 0, "failed to extract ip/port for \"%v\": %v", str, err.Error()) + log("ep", 0, "failed to extract ip/port for \"%v\": %v, ignoring endpoint", str, err.Error()) continue } ports := []int{port} // append all default ports - if CfgGetSwitch("append-default-ports") { - for _, port2 := range CfgGetIntSlice("port") { + if getParamSwitch("append-default-ports") { + for _, port2 := range getParamIntSlice("port") { if port != port2 { ports = append(ports, port2) } } } - numParsed += parseIPPorts(ip, ports, isIPv6) + numParsed += parseIPOrCIDR(ip, ports, isIPv6) } } - logIf(totalIPv6Skipped > 0, "ep", 0, "skipping %v IPv6 targets due to no-ipv6 flag", totalIPv6Skipped) - log("ep", 1, "finished parsing endpoints: got %v, total %v", numParsed, endpoints.Len()) + logIf(totalIPv6Skipped > 0, "ep", 0, "skipped %v IPv6 targets due to no-ipv6 flag", totalIPv6Skipped) + log("ep", 1, "finished parsing endpoints: parsed %v out of total %v", numParsed, endpoints.Len()) } func init() { endpoints = list.New() delayedEndpoints = list.New() - CfgRegister("port", []int{8291}, "one or more default ports") - CfgRegister("max-aps", 5, "maximum number of attempts per second for an endpoint") - CfgRegisterSwitch("no-ipv6", "skip IPv6 entries") - CfgRegisterSwitch("append-default-ports", "always append default ports even for targets in host:port format") - CfgRegisterSwitch("strict-subnets", "strict subnet behaviour: ignore network and broadcast addresses in /30 and bigger subnets") - - CfgRegisterSwitch("keep-endpoint-on-good", "keep processing endpoint if a login/password was found") + registerParam("port", []int{8291}, "one or more default ports") + registerParam("max-aps", 5, "maximum number of attempts per second for an endpoint") + registerSwitch("no-ipv6", "skip IPv6 entries") + registerSwitch("append-default-ports", "always append default ports even for targets in host:port format") + registerSwitch("strict-subnets", "strict subnet behaviour: ignore network and broadcast addresses in /30 and bigger subnets") - CfgRegister("conn-ratio", 0.15, "keep a failed endpoint if its bad/good connection ratio is lower than this value") - CfgRegister("max-bad-after-good-conn", 5, "how many consecutive bad connections to allow after a good connection") - CfgRegister("max-bad-conn", 20, "always remove endpoint after this many consecutive bad connections") - CfgRegister("min-bad-conn", 2, "do not consider removing an endpoint if it does not have this many consecutive bad connections") + registerSwitch("keep-endpoint-on-good", "keep processing endpoint if a login/password was found") - CfgRegister("proto-error-ratio", 0.25, "keep endpoints with a protocol error if their protocol error ratio is lower than this value") - CfgRegister("max-proto-errors", 20, "always remove endpoint after this many consecutive protocol errors") - CfgRegister("min-proto-errors", 4, "do not consider removing an endpoint if it does not have this many consecutive protocol errors") + registerParam("conn-ratio", 0.15, "keep a failed endpoint if its bad/good connection ratio is lower than this value") + registerParam("max-bad-after-good-conn", 5, "how many consecutive bad connections to allow after a good connection") + registerParam("max-bad-conn", 20, "always remove endpoint after this many consecutive bad connections") + registerParam("min-bad-conn", 2, "do not consider removing an endpoint if it does not have this many consecutive bad connections") - CfgRegister("read-error-ratio", 0.25, "keep endpoints with a read error if their read error ratio is lower than this value") - CfgRegister("max-read-errors", 20, "always remove endpoint after this many consecutive read errors") - CfgRegister("min-read-errors", 3, "do not consider removing an endpoint if it does not have this many consecutive read errors") + registerParam("proto-error-ratio", 0.25, "keep endpoints with a protocol error if their protocol error ratio is lower than this value") + registerParam("max-proto-errors", 20, "always remove endpoint after this many consecutive protocol errors") + registerParam("min-proto-errors", 4, "do not consider removing an endpoint if it does not have this many consecutive protocol errors") - CfgRegister("no-response-delay-ms", 2000, "wait for this number of ms if an endpoint does not respond") - CfgRegister("read-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a read error") - CfgRegister("protocol-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a protocol error") + registerParam("read-error-ratio", 0.25, "keep endpoints with a read error if their read error ratio is lower than this value") + registerParam("max-read-errors", 20, "always remove endpoint after this many consecutive read errors") + registerParam("min-read-errors", 3, "do not consider removing an endpoint if it does not have this many consecutive read errors") + registerParam("no-response-delay-ms", 2000, "wait for this number of ms if an endpoint does not respond") + registerParam("read-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a read error") + registerParam("protocol-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a protocol error") } diff --git a/go.mod b/go.mod index 9d22a28..e946f94 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,3 @@ module mtbf -go 1.18 - -require github.com/emirpasic/gods v1.18.1 +go 1.18 \ No newline at end of file diff --git a/go.sum b/go.sum deleted file mode 100644 index b5ad666..0000000 --- a/go.sum +++ /dev/null @@ -1,2 +0,0 @@ -github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= -github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= diff --git a/source.go b/source.go index 9adec5d..e186b62 100644 --- a/source.go +++ b/source.go @@ -9,17 +9,19 @@ import ( ) type Source struct { - name, plainParmName, filesParmName string + name string // name of this source - validator func(item string) (string, error) + plain []string // sources from commandline + contents []string // sources from files - plain []string - files []*os.File - fileNames []string + files []*os.File // file pointers + fileNames []string // file names + plainParmName string // name of "plain" commandline parameter + filesParmName string // name of "files" commandline parameter - contents []string + transform func(item string) (string, error) // optional transformation function - fetchMutex sync.Mutex + fetchMutex sync.Mutex // sync mutex } // both -1: exhausted @@ -29,51 +31,108 @@ type SourcePos struct { contentIdx int } +// String converts a SourcePos to its string representation. func (pos *SourcePos) String() string { return "P" + strconv.Itoa(pos.plainIdx) + "/C" + strconv.Itoa(pos.contentIdx) } +// Exhausted checks if a SourcePos can no longer produce any sources. func (pos *SourcePos) Exhausted() bool { return pos.plainIdx == -1 && pos.contentIdx == -1 } -func ipValidator(item string) (res string, err error) { - return item, nil +// Reset moves a SourcePos to its starting position. +func (pos *SourcePos) Reset() { + pos.plainIdx = 0 + pos.contentIdx = 0 + log("src", 3, "resetting source pos") } -func passwordValidator(item string) (res string, err error) { - if CfgGetSwitch("no-password-trim") { +// passwordTransform is a transformation function for a password. +func passwordTransform(item string) (res string, err error) { + if getParamSwitch("no-password-trim") { return item, nil } else { return strings.TrimSpace(item), nil } } +// SrcIP, SrcLogin and SrcPassword represent different sources. var SrcIP Source = Source{name: "ip", plainParmName: "ip", filesParmName: "ip-file"} var SrcLogin Source = Source{name: "login", plainParmName: "login", filesParmName: "login-file"} var SrcPassword Source = Source{name: "password", plainParmName: "password", - filesParmName: "password-file", validator: passwordValidator} + filesParmName: "password-file", transform: passwordTransform} + +// String converts a Source to its string representation. +func (src *Source) String() string { + return src.name +} -func (src *Source) validate(item string) (res string, err error) { - if src.validator != nil { - res, err := src.validator(item) +// ValidateAndTransformItem attempts to validate a source item +// and performs transformations, if any. +func (src *Source) ValidateAndTransformItem(item string) (res string, err error) { + if src.transform != nil { + res, err := src.transform(item) if err != nil { - log("src", 1, "error validating %v \"%v\": %v", src.name, item, err.Error()) + log("src", 1, "error validating %v \"%v\": %v", src, item, err.Error()) + res = "" } + return res, err } else { return item, nil } } -func (src *Source) parsePlain() { - if src.plain == nil { - src.plain = []string{} +// FetchFromSlice retrieves an item from a string slice and optionally increments its current position. +func (src *Source) FetchFromSlice(name string, idx *int, slice []string, inc bool) (res string, empty bool) { + if *idx == -1 { // exhausted + log("src", 5, "fetch %v from %v: idx is -1, return empty", src, name) + return "", true } - for _, plain := range CfgGetStringSlice(src.plainParmName) { - var err error - plain, err = src.validate(plain) + if *idx >= len(slice) { + log("src", 5, "fetch %v from %v: idx >= slice length (%v >= %v), marking as exhausted, return empty", src, name, *idx, len(slice)) + *idx = -1 + return "", true + } + + res = slice[*idx] + log("src", 5, "fetch %v from %v: ok, got %v at idx %v", src, name, res, *idx) + + if inc { + *idx = *idx + 1 + log("src", 5, "fetch %v from %v: incrementing idx to %v", src, name, *idx) + } + + return res, false +} + +// FetchOne retrieves an item from a Source with a specified SourcePos. +func (src *Source) FetchOne(pos *SourcePos, inc bool) (res string, empty bool) { + src.fetchMutex.Lock() + defer src.fetchMutex.Unlock() + + if getParamSwitch("file-contents-first") { + res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc) + if empty { + res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc) + } + } else { + res, empty = src.FetchFromSlice("plain", &pos.plainIdx, src.plain, inc) + if empty { + res, empty = src.FetchFromSlice("contents", &pos.contentIdx, src.contents, inc) + } + } + + logIf(empty, "src", 2, "exhausted source %v for pos %v", src, pos.String()) + return res, empty +} + +// ParsePlain parses commandline parameters for a Source. +func (src *Source) ParsePlain() { + for _, plain := range getParamStringSlice(src.plainParmName) { + plain, err := src.ValidateAndTransformItem(plain) if err != nil { continue } @@ -82,20 +141,14 @@ func (src *Source) parsePlain() { } if len(src.plain) > 0 { - log("src", 1, "parsed %v %v items", len(src.plain), src.name) + log("src", 1, "parsed %v %v items", len(src.plain), src) } } -func (src *Source) openFiles() { - if src.files == nil { - src.files = []*os.File{} - } - - if src.fileNames == nil { - src.fileNames = []string{} - } +// OpenFiles opens all files for a Source. +func (src *Source) OpenFiles() { + fileNames := getParamStringSlice(src.filesParmName) - fileNames := CfgGetStringSlice(src.filesParmName) for _, fileName := range fileNames { f, err := os.Open(fileName) if err != nil { @@ -107,16 +160,12 @@ func (src *Source) openFiles() { } if len(src.files) > 0 { - log("src", 1, "opened %v %v files", len(src.files), src.name) + log("src", 1, "opened %v %v files", len(src.files), src) } } -// this parses all source files -func (src *Source) parseFiles() { - if src.contents == nil { - src.contents = []string{} - } - +// ParseFiles parses all files for a Source. +func (src *Source) ParseFiles() { for i, file := range src.files { fileName := src.fileNames[i] log("src", 1, "parsing %v", fileName) @@ -129,7 +178,7 @@ func (src *Source) parseFiles() { continue } - value, err := src.validate(text) + value, err := src.ValidateAndTransformItem(text) if err != nil { continue } @@ -140,39 +189,41 @@ func (src *Source) parseFiles() { scannerErr := scanner.Err() failIf(scannerErr != nil, "error reading source file \"%v\": %v", fileName, scannerErr) - log("src", 1, "ok: parsed %v, got %v contents, %v total", fileName, thisTotal, len(src.contents)) + log("src", 1, "ok: parsed \"%v\", got %v contents, %v total", fileName, thisTotal, len(src.contents)) } } -func (src *Source) failIfEmpty() { - failIf(len(src.contents)+len(src.plain) == 0, "no %vs defined: check %v and %v parameters", src.name, src.plainParmName, src.filesParmName) +// FailIfEmpty throws an exception if a source is empty. +func (src *Source) FailIfEmpty() { + failIf(len(src.contents)+len(src.plain) == 0, "no %vs defined: check %v and %v parameters", src, src.plainParmName, src.filesParmName) } -func (src *Source) reportLoaded() { - log("src", 0, "loaded %vs: %v items from commandline and %v items from files", src.name, len(src.plain), len(src.contents)) +// ReportLoaded prints a console message about the number of loaded items for a Source. +func (src *Source) ReportLoaded() { + log("src", 0, "loaded %vs: %v items from commandline and %v items from files", src, len(src.plain), len(src.contents)) } -func (src *Source) load(wg *sync.WaitGroup) { +// LoadSource fills a Source with data (from commandline and from files). +func (src *Source) LoadSource(wg *sync.WaitGroup) { if wg != nil { defer wg.Done() } - if src.name == "password" && CfgGetSwitch("add-empty-password") { - if src.plain == nil { - src.plain = []string{} - } + if src.name == "password" && getParamSwitch("add-empty-password") { src.plain = append(src.plain, "") } - src.parsePlain() - src.openFiles() - defer src.closeFiles() + src.ParsePlain() + + src.OpenFiles() + defer src.CloseSource() - src.parseFiles() - src.failIfEmpty() + src.ParseFiles() + src.FailIfEmpty() } -func (src *Source) closeFiles() { +// CloseSource closes all files for a Source. +func (src *Source) CloseSource() { l := len(src.files) for _, file := range src.files { if file != nil { @@ -180,109 +231,57 @@ func (src *Source) closeFiles() { } } - src.files = []*os.File{} - src.fileNames = []string{} - log("src", 1, "closed all %v %v files", l, src.name) + src.files = nil + src.fileNames = nil + log("src", 1, "closed all %v %v files", l, src) } -func (src *Source) fetchFromSlice(name string, idx *int, slice []string, inc bool) (res string, empty bool) { - if *idx == -1 { - // exhausted - log("src", 3, "fetch %v from %v: idx is -1, return empty", src.name, name) - return "", true - } - - if *idx >= len(slice) { - log("src", 3, "fetch %v from %v: idx >= slice length (%v >= %v), marking as exhausted, return empty", src.name, name, *idx, len(slice)) - *idx = -1 - return "", true - } - - res = slice[*idx] - log("src", 3, "fetch %v from %v: ok, got %v at idx %v", src.name, name, res, *idx) - - if inc { - *idx = *idx + 1 - log("src", 3, "fetch %v from %v: incrementing idx to %v", src.name, name, *idx) - } - - return res, false -} - -// retrieves an item from source -// increments pos -func (src *Source) FetchOne(pos *SourcePos, inc bool) (res string, empty bool) { - src.fetchMutex.Lock() - defer src.fetchMutex.Unlock() - - if CfgGetSwitch("file-contents-first") { - res, empty = src.fetchFromSlice("contents", &pos.contentIdx, src.contents, inc) - if empty { - res, empty = src.fetchFromSlice("plain", &pos.plainIdx, src.plain, inc) - } - } else { - res, empty = src.fetchFromSlice("plain", &pos.plainIdx, src.plain, inc) - if empty { - res, empty = src.fetchFromSlice("contents", &pos.contentIdx, src.contents, inc) - } - } - - logIf(empty, "src", 2, "exhausted source %v for pos %v", src.name, pos.String()) - return res, empty -} - -func (pos *SourcePos) Reset() { - pos.plainIdx = 0 - pos.contentIdx = 0 - log("src", 3, "resetting source pos") -} +// --- +// --- +// --- +// LoadSources loads contents for all sources, both from commandline and from files. func LoadSources() { log("src", 1, "loading sources") var wg sync.WaitGroup wg.Add(3) - go SrcIP.load(&wg) - go SrcLogin.load(&wg) - go SrcPassword.load(&wg) + go SrcIP.LoadSource(&wg) + go SrcLogin.LoadSource(&wg) + go SrcPassword.LoadSource(&wg) wg.Wait() - SrcIP.reportLoaded() - SrcLogin.reportLoaded() - SrcPassword.reportLoaded() + SrcIP.ReportLoaded() + SrcLogin.ReportLoaded() + SrcPassword.ReportLoaded() ParseEndpoints(SrcIP.plain) - - // TODO: dynamic loading of contents - // currently this just dumps everything to a slice ParseEndpoints(SrcIP.contents) log("src", 1, "ok: finished loading sources") } +// CloseSources closes all source files. func CloseSources() { log("src", 1, "closing sources") - SrcIP.closeFiles() - SrcLogin.closeFiles() - SrcPassword.closeFiles() + + SrcIP.CloseSource() + SrcLogin.CloseSource() + SrcPassword.CloseSource() + log("src", 1, "ok: finished closing sources") } func init() { - CfgRegister("ip", []string{}, "IPs or subnets in CIDR notation") - CfgRegister("ip-file", []string{}, "paths to files with IPs or subnets in CIDR notation (one entry per line)") - CfgRegister("login", []string{}, "one or more logins") - CfgRegister("login-file", []string{}, "paths to files with logins (one entry per line)") - CfgRegister("password", []string{}, "one or more passwords") - CfgRegister("password-file", []string{}, "paths to files with passwords (one entry per line)") - - CfgRegisterSwitch("add-empty-password", "insert an empty password to the password list") - //CfgRegisterSwitch("no-source-validation", "do not attempt to validate and count lines in source files") - CfgRegisterSwitch("no-password-trim", "preserve leading and trailing spaces in passwords") - CfgRegisterSwitch("logins-first", "increment logins before passwords") - CfgRegisterSwitch("file-contents-first", "try to go through source files first, defer commandline args for later") - - // - //CfgRegisterSwitch("ignore-network-names", "always skip non-IPv4/non-IPv6 entries, do not attempt to resolve them as domain names") - // + registerParam("ip", []string{}, "IPs or subnets in CIDR notation") + registerParam("ip-file", []string{}, "paths to files with IPs or subnets in CIDR notation (one entry per line)") + registerParam("login", []string{}, "one or more logins") + registerParam("login-file", []string{}, "paths to files with logins (one entry per line)") + registerParam("password", []string{}, "one or more passwords") + registerParam("password-file", []string{}, "paths to files with passwords (one entry per line)") + + registerSwitch("add-empty-password", "insert an empty password to the password list") + registerSwitch("no-password-trim", "preserve leading and trailing spaces in passwords") + registerSwitch("logins-first", "increment logins before passwords") + registerSwitch("file-contents-first", "try to go through source files first, defer commandline args for later") } diff --git a/winbox.go b/winbox.go index dc1d430..26b9c43 100644 --- a/winbox.go +++ b/winbox.go @@ -23,7 +23,7 @@ func NewWinbox(task *Task, conn net.Conn) *Winbox { winbox.conn = conn winbox.xwaParity = false winbox.xwbParity = false - winbox.w = NewWCurve() + winbox.w = newWCurve() winbox.stage = -1 winbox.user = task.login winbox.pass = task.password