|  | // Copyright 2012 The Go Authors. All rights reserved. | 
|  | // Use of this source code is governed by a BSD-style | 
|  | // license that can be found in the LICENSE file. | 
|  |  | 
|  | // This file implements the Socialist Millionaires Protocol as described in | 
|  | // http://www.cypherpunks.ca/otr/Protocol-v2-3.1.0.html. The protocol | 
|  | // specification is required in order to understand this code and, where | 
|  | // possible, the variable names in the code match up with the spec. | 
|  |  | 
|  | package otr | 
|  |  | 
|  | import ( | 
|  | "bytes" | 
|  | "crypto/sha256" | 
|  | "errors" | 
|  | "hash" | 
|  | "math/big" | 
|  | ) | 
|  |  | 
// smpFailure is the error type used for SMP outcomes that callers compare
// against directly; the string value is the error message.
type smpFailure string

// Error implements the error interface by returning the message unchanged.
func (f smpFailure) Error() string { return string(f) }

var (
	// smpFailureError reports that an SMP exchange ran but did not succeed
	// (peer abort or mismatched secrets).
	smpFailureError = smpFailure("otr: SMP protocol failed")

	// smpSecretMissingError reports that an SMP1 message arrived before a
	// mutual secret was available.
	smpSecretMissingError = smpFailure("otr: mutual secret needed")
)

// smpVersion is the version byte hashed into the SMP secret.
const smpVersion = 1

// SMP state machine states. smpStateN means SMP message N is expected next;
// smpState1 is also the idle state.
const (
	smpState1 = iota
	smpState2
	smpState3
	smpState4
)
|  |  | 
|  | type smpState struct { | 
|  | state                  int | 
|  | a2, a3, b2, b3, pb, qb *big.Int | 
|  | g2a, g3a               *big.Int | 
|  | g2, g3                 *big.Int | 
|  | g3b, papb, qaqb, ra    *big.Int | 
|  | saved                  *tlv | 
|  | secret                 *big.Int | 
|  | question               string | 
|  | } | 
|  |  | 
|  | func (c *Conversation) startSMP(question string) (tlvs []tlv) { | 
|  | if c.smp.state != smpState1 { | 
|  | tlvs = append(tlvs, c.generateSMPAbort()) | 
|  | } | 
|  | tlvs = append(tlvs, c.generateSMP1(question)) | 
|  | c.smp.question = "" | 
|  | c.smp.state = smpState2 | 
|  | return | 
|  | } | 
|  |  | 
|  | func (c *Conversation) resetSMP() { | 
|  | c.smp.state = smpState1 | 
|  | c.smp.secret = nil | 
|  | c.smp.question = "" | 
|  | } | 
|  |  | 
|  | func (c *Conversation) processSMP(in tlv) (out tlv, complete bool, err error) { | 
|  | data := in.data | 
|  |  | 
|  | switch in.typ { | 
|  | case tlvTypeSMPAbort: | 
|  | if c.smp.state != smpState1 { | 
|  | err = smpFailureError | 
|  | } | 
|  | c.resetSMP() | 
|  | return | 
|  | case tlvTypeSMP1WithQuestion: | 
|  | // We preprocess this into a SMP1 message. | 
|  | nulPos := bytes.IndexByte(data, 0) | 
|  | if nulPos == -1 { | 
|  | err = errors.New("otr: SMP message with question didn't contain a NUL byte") | 
|  | return | 
|  | } | 
|  | c.smp.question = string(data[:nulPos]) | 
|  | data = data[nulPos+1:] | 
|  | } | 
|  |  | 
|  | numMPIs, data, ok := getU32(data) | 
|  | if !ok || numMPIs > 20 { | 
|  | err = errors.New("otr: corrupt SMP message") | 
|  | return | 
|  | } | 
|  |  | 
|  | mpis := make([]*big.Int, numMPIs) | 
|  | for i := range mpis { | 
|  | var ok bool | 
|  | mpis[i], data, ok = getMPI(data) | 
|  | if !ok { | 
|  | err = errors.New("otr: corrupt SMP message") | 
|  | return | 
|  | } | 
|  | } | 
|  |  | 
|  | switch in.typ { | 
|  | case tlvTypeSMP1, tlvTypeSMP1WithQuestion: | 
|  | if c.smp.state != smpState1 { | 
|  | c.resetSMP() | 
|  | out = c.generateSMPAbort() | 
|  | return | 
|  | } | 
|  | if c.smp.secret == nil { | 
|  | err = smpSecretMissingError | 
|  | return | 
|  | } | 
|  | if err = c.processSMP1(mpis); err != nil { | 
|  | return | 
|  | } | 
|  | c.smp.state = smpState3 | 
|  | out = c.generateSMP2() | 
|  | case tlvTypeSMP2: | 
|  | if c.smp.state != smpState2 { | 
|  | c.resetSMP() | 
|  | out = c.generateSMPAbort() | 
|  | return | 
|  | } | 
|  | if out, err = c.processSMP2(mpis); err != nil { | 
|  | out = c.generateSMPAbort() | 
|  | return | 
|  | } | 
|  | c.smp.state = smpState4 | 
|  | case tlvTypeSMP3: | 
|  | if c.smp.state != smpState3 { | 
|  | c.resetSMP() | 
|  | out = c.generateSMPAbort() | 
|  | return | 
|  | } | 
|  | if out, err = c.processSMP3(mpis); err != nil { | 
|  | return | 
|  | } | 
|  | c.smp.state = smpState1 | 
|  | c.smp.secret = nil | 
|  | complete = true | 
|  | case tlvTypeSMP4: | 
|  | if c.smp.state != smpState4 { | 
|  | c.resetSMP() | 
|  | out = c.generateSMPAbort() | 
|  | return | 
|  | } | 
|  | if err = c.processSMP4(mpis); err != nil { | 
|  | out = c.generateSMPAbort() | 
|  | return | 
|  | } | 
|  | c.smp.state = smpState1 | 
|  | c.smp.secret = nil | 
|  | complete = true | 
|  | default: | 
|  | panic("unknown SMP message") | 
|  | } | 
|  |  | 
|  | return | 
|  | } | 
|  |  | 
|  | func (c *Conversation) calcSMPSecret(mutualSecret []byte, weStarted bool) { | 
|  | h := sha256.New() | 
|  | h.Write([]byte{smpVersion}) | 
|  | if weStarted { | 
|  | h.Write(c.PrivateKey.PublicKey.Fingerprint()) | 
|  | h.Write(c.TheirPublicKey.Fingerprint()) | 
|  | } else { | 
|  | h.Write(c.TheirPublicKey.Fingerprint()) | 
|  | h.Write(c.PrivateKey.PublicKey.Fingerprint()) | 
|  | } | 
|  | h.Write(c.SSID[:]) | 
|  | h.Write(mutualSecret) | 
|  | c.smp.secret = new(big.Int).SetBytes(h.Sum(nil)) | 
|  | } | 
|  |  | 
|  | func (c *Conversation) generateSMP1(question string) tlv { | 
|  | var randBuf [16]byte | 
|  | c.smp.a2 = c.randMPI(randBuf[:]) | 
|  | c.smp.a3 = c.randMPI(randBuf[:]) | 
|  | g2a := new(big.Int).Exp(g, c.smp.a2, p) | 
|  | g3a := new(big.Int).Exp(g, c.smp.a3, p) | 
|  | h := sha256.New() | 
|  |  | 
|  | r2 := c.randMPI(randBuf[:]) | 
|  | r := new(big.Int).Exp(g, r2, p) | 
|  | c2 := new(big.Int).SetBytes(hashMPIs(h, 1, r)) | 
|  | d2 := new(big.Int).Mul(c.smp.a2, c2) | 
|  | d2.Sub(r2, d2) | 
|  | d2.Mod(d2, q) | 
|  | if d2.Sign() < 0 { | 
|  | d2.Add(d2, q) | 
|  | } | 
|  |  | 
|  | r3 := c.randMPI(randBuf[:]) | 
|  | r.Exp(g, r3, p) | 
|  | c3 := new(big.Int).SetBytes(hashMPIs(h, 2, r)) | 
|  | d3 := new(big.Int).Mul(c.smp.a3, c3) | 
|  | d3.Sub(r3, d3) | 
|  | d3.Mod(d3, q) | 
|  | if d3.Sign() < 0 { | 
|  | d3.Add(d3, q) | 
|  | } | 
|  |  | 
|  | var ret tlv | 
|  | if len(question) > 0 { | 
|  | ret.typ = tlvTypeSMP1WithQuestion | 
|  | ret.data = append(ret.data, question...) | 
|  | ret.data = append(ret.data, 0) | 
|  | } else { | 
|  | ret.typ = tlvTypeSMP1 | 
|  | } | 
|  | ret.data = appendU32(ret.data, 6) | 
|  | ret.data = appendMPIs(ret.data, g2a, c2, d2, g3a, c3, d3) | 
|  | return ret | 
|  | } | 
|  |  | 
|  | func (c *Conversation) processSMP1(mpis []*big.Int) error { | 
|  | if len(mpis) != 6 { | 
|  | return errors.New("otr: incorrect number of arguments in SMP1 message") | 
|  | } | 
|  | g2a := mpis[0] | 
|  | c2 := mpis[1] | 
|  | d2 := mpis[2] | 
|  | g3a := mpis[3] | 
|  | c3 := mpis[4] | 
|  | d3 := mpis[5] | 
|  | h := sha256.New() | 
|  |  | 
|  | r := new(big.Int).Exp(g, d2, p) | 
|  | s := new(big.Int).Exp(g2a, c2, p) | 
|  | r.Mul(r, s) | 
|  | r.Mod(r, p) | 
|  | t := new(big.Int).SetBytes(hashMPIs(h, 1, r)) | 
|  | if c2.Cmp(t) != 0 { | 
|  | return errors.New("otr: ZKP c2 incorrect in SMP1 message") | 
|  | } | 
|  | r.Exp(g, d3, p) | 
|  | s.Exp(g3a, c3, p) | 
|  | r.Mul(r, s) | 
|  | r.Mod(r, p) | 
|  | t.SetBytes(hashMPIs(h, 2, r)) | 
|  | if c3.Cmp(t) != 0 { | 
|  | return errors.New("otr: ZKP c3 incorrect in SMP1 message") | 
|  | } | 
|  |  | 
|  | c.smp.g2a = g2a | 
|  | c.smp.g3a = g3a | 
|  | return nil | 
|  | } | 
|  |  | 
|  | func (c *Conversation) generateSMP2() tlv { | 
|  | var randBuf [16]byte | 
|  | b2 := c.randMPI(randBuf[:]) | 
|  | c.smp.b3 = c.randMPI(randBuf[:]) | 
|  | r2 := c.randMPI(randBuf[:]) | 
|  | r3 := c.randMPI(randBuf[:]) | 
|  | r4 := c.randMPI(randBuf[:]) | 
|  | r5 := c.randMPI(randBuf[:]) | 
|  | r6 := c.randMPI(randBuf[:]) | 
|  |  | 
|  | g2b := new(big.Int).Exp(g, b2, p) | 
|  | g3b := new(big.Int).Exp(g, c.smp.b3, p) | 
|  |  | 
|  | r := new(big.Int).Exp(g, r2, p) | 
|  | h := sha256.New() | 
|  | c2 := new(big.Int).SetBytes(hashMPIs(h, 3, r)) | 
|  | d2 := new(big.Int).Mul(b2, c2) | 
|  | d2.Sub(r2, d2) | 
|  | d2.Mod(d2, q) | 
|  | if d2.Sign() < 0 { | 
|  | d2.Add(d2, q) | 
|  | } | 
|  |  | 
|  | r.Exp(g, r3, p) | 
|  | c3 := new(big.Int).SetBytes(hashMPIs(h, 4, r)) | 
|  | d3 := new(big.Int).Mul(c.smp.b3, c3) | 
|  | d3.Sub(r3, d3) | 
|  | d3.Mod(d3, q) | 
|  | if d3.Sign() < 0 { | 
|  | d3.Add(d3, q) | 
|  | } | 
|  |  | 
|  | c.smp.g2 = new(big.Int).Exp(c.smp.g2a, b2, p) | 
|  | c.smp.g3 = new(big.Int).Exp(c.smp.g3a, c.smp.b3, p) | 
|  | c.smp.pb = new(big.Int).Exp(c.smp.g3, r4, p) | 
|  | c.smp.qb = new(big.Int).Exp(g, r4, p) | 
|  | r.Exp(c.smp.g2, c.smp.secret, p) | 
|  | c.smp.qb.Mul(c.smp.qb, r) | 
|  | c.smp.qb.Mod(c.smp.qb, p) | 
|  |  | 
|  | s := new(big.Int) | 
|  | s.Exp(c.smp.g2, r6, p) | 
|  | r.Exp(g, r5, p) | 
|  | s.Mul(r, s) | 
|  | s.Mod(s, p) | 
|  | r.Exp(c.smp.g3, r5, p) | 
|  | cp := new(big.Int).SetBytes(hashMPIs(h, 5, r, s)) | 
|  |  | 
|  | // D5 = r5 - r4 cP mod q and D6 = r6 - y cP mod q | 
|  |  | 
|  | s.Mul(r4, cp) | 
|  | r.Sub(r5, s) | 
|  | d5 := new(big.Int).Mod(r, q) | 
|  | if d5.Sign() < 0 { | 
|  | d5.Add(d5, q) | 
|  | } | 
|  |  | 
|  | s.Mul(c.smp.secret, cp) | 
|  | r.Sub(r6, s) | 
|  | d6 := new(big.Int).Mod(r, q) | 
|  | if d6.Sign() < 0 { | 
|  | d6.Add(d6, q) | 
|  | } | 
|  |  | 
|  | var ret tlv | 
|  | ret.typ = tlvTypeSMP2 | 
|  | ret.data = appendU32(ret.data, 11) | 
|  | ret.data = appendMPIs(ret.data, g2b, c2, d2, g3b, c3, d3, c.smp.pb, c.smp.qb, cp, d5, d6) | 
|  | return ret | 
|  | } | 
|  |  | 
|  | func (c *Conversation) processSMP2(mpis []*big.Int) (out tlv, err error) { | 
|  | if len(mpis) != 11 { | 
|  | err = errors.New("otr: incorrect number of arguments in SMP2 message") | 
|  | return | 
|  | } | 
|  | g2b := mpis[0] | 
|  | c2 := mpis[1] | 
|  | d2 := mpis[2] | 
|  | g3b := mpis[3] | 
|  | c3 := mpis[4] | 
|  | d3 := mpis[5] | 
|  | pb := mpis[6] | 
|  | qb := mpis[7] | 
|  | cp := mpis[8] | 
|  | d5 := mpis[9] | 
|  | d6 := mpis[10] | 
|  | h := sha256.New() | 
|  |  | 
|  | r := new(big.Int).Exp(g, d2, p) | 
|  | s := new(big.Int).Exp(g2b, c2, p) | 
|  | r.Mul(r, s) | 
|  | r.Mod(r, p) | 
|  | s.SetBytes(hashMPIs(h, 3, r)) | 
|  | if c2.Cmp(s) != 0 { | 
|  | err = errors.New("otr: ZKP c2 failed in SMP2 message") | 
|  | return | 
|  | } | 
|  |  | 
|  | r.Exp(g, d3, p) | 
|  | s.Exp(g3b, c3, p) | 
|  | r.Mul(r, s) | 
|  | r.Mod(r, p) | 
|  | s.SetBytes(hashMPIs(h, 4, r)) | 
|  | if c3.Cmp(s) != 0 { | 
|  | err = errors.New("otr: ZKP c3 failed in SMP2 message") | 
|  | return | 
|  | } | 
|  |  | 
|  | c.smp.g2 = new(big.Int).Exp(g2b, c.smp.a2, p) | 
|  | c.smp.g3 = new(big.Int).Exp(g3b, c.smp.a3, p) | 
|  |  | 
|  | r.Exp(g, d5, p) | 
|  | s.Exp(c.smp.g2, d6, p) | 
|  | r.Mul(r, s) | 
|  | s.Exp(qb, cp, p) | 
|  | r.Mul(r, s) | 
|  | r.Mod(r, p) | 
|  |  | 
|  | s.Exp(c.smp.g3, d5, p) | 
|  | t := new(big.Int).Exp(pb, cp, p) | 
|  | s.Mul(s, t) | 
|  | s.Mod(s, p) | 
|  | t.SetBytes(hashMPIs(h, 5, s, r)) | 
|  | if cp.Cmp(t) != 0 { | 
|  | err = errors.New("otr: ZKP cP failed in SMP2 message") | 
|  | return | 
|  | } | 
|  |  | 
|  | var randBuf [16]byte | 
|  | r4 := c.randMPI(randBuf[:]) | 
|  | r5 := c.randMPI(randBuf[:]) | 
|  | r6 := c.randMPI(randBuf[:]) | 
|  | r7 := c.randMPI(randBuf[:]) | 
|  |  | 
|  | pa := new(big.Int).Exp(c.smp.g3, r4, p) | 
|  | r.Exp(c.smp.g2, c.smp.secret, p) | 
|  | qa := new(big.Int).Exp(g, r4, p) | 
|  | qa.Mul(qa, r) | 
|  | qa.Mod(qa, p) | 
|  |  | 
|  | r.Exp(g, r5, p) | 
|  | s.Exp(c.smp.g2, r6, p) | 
|  | r.Mul(r, s) | 
|  | r.Mod(r, p) | 
|  |  | 
|  | s.Exp(c.smp.g3, r5, p) | 
|  | cp.SetBytes(hashMPIs(h, 6, s, r)) | 
|  |  | 
|  | r.Mul(r4, cp) | 
|  | d5 = new(big.Int).Sub(r5, r) | 
|  | d5.Mod(d5, q) | 
|  | if d5.Sign() < 0 { | 
|  | d5.Add(d5, q) | 
|  | } | 
|  |  | 
|  | r.Mul(c.smp.secret, cp) | 
|  | d6 = new(big.Int).Sub(r6, r) | 
|  | d6.Mod(d6, q) | 
|  | if d6.Sign() < 0 { | 
|  | d6.Add(d6, q) | 
|  | } | 
|  |  | 
|  | r.ModInverse(qb, p) | 
|  | qaqb := new(big.Int).Mul(qa, r) | 
|  | qaqb.Mod(qaqb, p) | 
|  |  | 
|  | ra := new(big.Int).Exp(qaqb, c.smp.a3, p) | 
|  | r.Exp(qaqb, r7, p) | 
|  | s.Exp(g, r7, p) | 
|  | cr := new(big.Int).SetBytes(hashMPIs(h, 7, s, r)) | 
|  |  | 
|  | r.Mul(c.smp.a3, cr) | 
|  | d7 := new(big.Int).Sub(r7, r) | 
|  | d7.Mod(d7, q) | 
|  | if d7.Sign() < 0 { | 
|  | d7.Add(d7, q) | 
|  | } | 
|  |  | 
|  | c.smp.g3b = g3b | 
|  | c.smp.qaqb = qaqb | 
|  |  | 
|  | r.ModInverse(pb, p) | 
|  | c.smp.papb = new(big.Int).Mul(pa, r) | 
|  | c.smp.papb.Mod(c.smp.papb, p) | 
|  | c.smp.ra = ra | 
|  |  | 
|  | out.typ = tlvTypeSMP3 | 
|  | out.data = appendU32(out.data, 8) | 
|  | out.data = appendMPIs(out.data, pa, qa, cp, d5, d6, ra, cr, d7) | 
|  | return | 
|  | } | 
|  |  | 
|  | func (c *Conversation) processSMP3(mpis []*big.Int) (out tlv, err error) { | 
|  | if len(mpis) != 8 { | 
|  | err = errors.New("otr: incorrect number of arguments in SMP3 message") | 
|  | return | 
|  | } | 
|  | pa := mpis[0] | 
|  | qa := mpis[1] | 
|  | cp := mpis[2] | 
|  | d5 := mpis[3] | 
|  | d6 := mpis[4] | 
|  | ra := mpis[5] | 
|  | cr := mpis[6] | 
|  | d7 := mpis[7] | 
|  | h := sha256.New() | 
|  |  | 
|  | r := new(big.Int).Exp(g, d5, p) | 
|  | s := new(big.Int).Exp(c.smp.g2, d6, p) | 
|  | r.Mul(r, s) | 
|  | s.Exp(qa, cp, p) | 
|  | r.Mul(r, s) | 
|  | r.Mod(r, p) | 
|  |  | 
|  | s.Exp(c.smp.g3, d5, p) | 
|  | t := new(big.Int).Exp(pa, cp, p) | 
|  | s.Mul(s, t) | 
|  | s.Mod(s, p) | 
|  | t.SetBytes(hashMPIs(h, 6, s, r)) | 
|  | if t.Cmp(cp) != 0 { | 
|  | err = errors.New("otr: ZKP cP failed in SMP3 message") | 
|  | return | 
|  | } | 
|  |  | 
|  | r.ModInverse(c.smp.qb, p) | 
|  | qaqb := new(big.Int).Mul(qa, r) | 
|  | qaqb.Mod(qaqb, p) | 
|  |  | 
|  | r.Exp(qaqb, d7, p) | 
|  | s.Exp(ra, cr, p) | 
|  | r.Mul(r, s) | 
|  | r.Mod(r, p) | 
|  |  | 
|  | s.Exp(g, d7, p) | 
|  | t.Exp(c.smp.g3a, cr, p) | 
|  | s.Mul(s, t) | 
|  | s.Mod(s, p) | 
|  | t.SetBytes(hashMPIs(h, 7, s, r)) | 
|  | if t.Cmp(cr) != 0 { | 
|  | err = errors.New("otr: ZKP cR failed in SMP3 message") | 
|  | return | 
|  | } | 
|  |  | 
|  | var randBuf [16]byte | 
|  | r7 := c.randMPI(randBuf[:]) | 
|  | rb := new(big.Int).Exp(qaqb, c.smp.b3, p) | 
|  |  | 
|  | r.Exp(qaqb, r7, p) | 
|  | s.Exp(g, r7, p) | 
|  | cr = new(big.Int).SetBytes(hashMPIs(h, 8, s, r)) | 
|  |  | 
|  | r.Mul(c.smp.b3, cr) | 
|  | d7 = new(big.Int).Sub(r7, r) | 
|  | d7.Mod(d7, q) | 
|  | if d7.Sign() < 0 { | 
|  | d7.Add(d7, q) | 
|  | } | 
|  |  | 
|  | out.typ = tlvTypeSMP4 | 
|  | out.data = appendU32(out.data, 3) | 
|  | out.data = appendMPIs(out.data, rb, cr, d7) | 
|  |  | 
|  | r.ModInverse(c.smp.pb, p) | 
|  | r.Mul(pa, r) | 
|  | r.Mod(r, p) | 
|  | s.Exp(ra, c.smp.b3, p) | 
|  | if r.Cmp(s) != 0 { | 
|  | err = smpFailureError | 
|  | } | 
|  |  | 
|  | return | 
|  | } | 
|  |  | 
|  | func (c *Conversation) processSMP4(mpis []*big.Int) error { | 
|  | if len(mpis) != 3 { | 
|  | return errors.New("otr: incorrect number of arguments in SMP4 message") | 
|  | } | 
|  | rb := mpis[0] | 
|  | cr := mpis[1] | 
|  | d7 := mpis[2] | 
|  | h := sha256.New() | 
|  |  | 
|  | r := new(big.Int).Exp(c.smp.qaqb, d7, p) | 
|  | s := new(big.Int).Exp(rb, cr, p) | 
|  | r.Mul(r, s) | 
|  | r.Mod(r, p) | 
|  |  | 
|  | s.Exp(g, d7, p) | 
|  | t := new(big.Int).Exp(c.smp.g3b, cr, p) | 
|  | s.Mul(s, t) | 
|  | s.Mod(s, p) | 
|  | t.SetBytes(hashMPIs(h, 8, s, r)) | 
|  | if t.Cmp(cr) != 0 { | 
|  | return errors.New("otr: ZKP cR failed in SMP4 message") | 
|  | } | 
|  |  | 
|  | r.Exp(rb, c.smp.a3, p) | 
|  | if r.Cmp(c.smp.papb) != 0 { | 
|  | return smpFailureError | 
|  | } | 
|  |  | 
|  | return nil | 
|  | } | 
|  |  | 
|  | func (c *Conversation) generateSMPAbort() tlv { | 
|  | return tlv{typ: tlvTypeSMPAbort} | 
|  | } | 
|  |  | 
|  | func hashMPIs(h hash.Hash, magic byte, mpis ...*big.Int) []byte { | 
|  | if h != nil { | 
|  | h.Reset() | 
|  | } else { | 
|  | h = sha256.New() | 
|  | } | 
|  |  | 
|  | h.Write([]byte{magic}) | 
|  | for _, mpi := range mpis { | 
|  | h.Write(appendMPI(nil, mpi)) | 
|  | } | 
|  | return h.Sum(nil) | 
|  | } |