// Copyright © 2017-2020 Trust Wallet.
//
// This file is part of Trust. The full Trust copyright notice, including
// terms governing use, modification, and redistribution, is contained in the
// file LICENSE at the root of the source code distribution tree.

7
#include "HDWallet.h"
8

9
#include "Base58.h"
10
#include "BinaryCoding.h"
11
#include "Bitcoin/SegwitAddress.h"
12
#include "Bitcoin/CashAddress.h"
13
#include "Coin.h"
14

15
#include <TrezorCrypto/bip32.h>
16
#include <TrezorCrypto/bip39.h>
17
#include <TrezorCrypto/curves.h>
18
#include <TrustWalletCore/TWHRP.h>
19

20
#include <algorithm>
#include <array>
#include <cassert>
#include <cstring>
21

22
using namespace TW;
23

24
namespace {
25

26
uint32_t fingerprint(HDNode *node, Hash::Hasher hasher);
27
std::string serialize(const HDNode *node, uint32_t fingerprint, uint32_t version, bool use_public, Hash::Hasher hasher);
28
bool deserialize(const std::string& extended, TWCurve curve, Hash::Hasher hasher, HDNode *node);
29
HDNode getNode(const HDWallet& wallet, TWCurve curve, const DerivationPath& derivationPath);
30
HDNode getMasterNode(const HDWallet& wallet, TWCurve curve);
31

32
const char* curveName(TWCurve curve);
33
} // namespace
34

35 1
bool HDWallet::isValid(const std::string& mnemonic) {
36 1
    return mnemonic_check(mnemonic.c_str()) != 0;
37
}
38

39 1
/// Creates a wallet with a freshly generated mnemonic of `strength` bits,
/// derives the seed from it using `passphrase`, and computes the entropy.
///
/// If `mnemonic_generate` reports failure (null/zero result), the wallet is
/// left empty. This matches the behavior of the `Data` constructor.
HDWallet::HDWallet(int strength, const std::string& passphrase)
    : seed(), mnemonic(), passphrase(passphrase) {
    // Zero-initialized so the buffer never holds indeterminate bytes.
    std::array<char, HDWallet::maxMnemomincSize> mnemonic_chars{};
    if (!mnemonic_generate(strength, mnemonic_chars.data())) {
        // Generation failed (presumably an unsupported strength). The buffer was
        // not filled in, so do not derive a seed from it.
        return;
    }
    mnemonic_to_seed(mnemonic_chars.data(), passphrase.c_str(), seed.data(), nullptr);
    mnemonic = mnemonic_chars.data();
    updateEntropy();
}
47

48 1
/// Restores a wallet from an existing `mnemonic`.
/// The seed is derived with `passphrase`, and the entropy is recomputed from the mnemonic.
HDWallet::HDWallet(const std::string& mnemonic, const std::string& passphrase)
    : seed(), mnemonic(mnemonic), passphrase(passphrase) {
    const char* const mnemonicStr = mnemonic.c_str();
    const char* const passphraseStr = passphrase.c_str();
    mnemonic_to_seed(mnemonicStr, passphraseStr, seed.data(), nullptr);
    updateEntropy();
}
53

54 1
/// Builds a wallet whose mnemonic encodes the raw entropy bytes in `data`.
/// If `mnemonic_from_data` fails, the wallet stays empty: no mnemonic, a
/// value-initialized seed, and no entropy.
HDWallet::HDWallet(const Data& data, const std::string& passphrase)
    : seed(), mnemonic(), passphrase(passphrase) {
    std::array<char, HDWallet::maxMnemomincSize> words;
    if (!mnemonic_from_data(data.data(), data.size(), words.data())) {
        return;
    }
    mnemonic_to_seed(words.data(), passphrase.c_str(), seed.data(), nullptr);
    mnemonic = words.data();
    updateEntropy();
}
63

64 1
/// Overwrites the wallet's secret material with zeros before it is released.
HDWallet::~HDWallet() {
    std::fill(seed.begin(), seed.end(), 0);
    std::fill(mnemonic.begin(), mnemonic.end(), 0);
    std::fill(passphrase.begin(), passphrase.end(), 0);
    // `entropy` is derived from the mnemonic (see updateEntropy), and it is used directly
    // to build extended-curve master nodes, so it is as sensitive as the mnemonic.
    std::fill(entropy.begin(), entropy.end(), 0);
    // NOTE(review): plain std::fill before deallocation may be elided by an optimizer;
    // consider a guaranteed wipe (e.g. trezor-crypto memzero) - confirm.
}
69

70 1
void HDWallet::updateEntropy() {
71
    // generate entropy (from mnemonic)
72 1
    Data entropyRaw(32 + 1);
73 1
    auto entropyBits = mnemonic_to_entropy(mnemonic.c_str(), entropyRaw.data());
74
    // copy to truncate
75 1
    entropy = data(entropyRaw.data(), entropyBits / 8);
76
}
77

78 1
/// Returns the private key of this wallet's master node for `curve`.
PrivateKey HDWallet::getMasterKey(TWCurve curve) const {
    const auto masterNode = getMasterNode(*this, curve);
    const uint8_t* const keyBegin = masterNode.private_key;
    return PrivateKey(Data(keyBegin, keyBegin + PrivateKey::size));
}
83

84 1
/// Returns the master node's `private_key_extension` bytes, wrapped as a PrivateKey.
PrivateKey HDWallet::getMasterKeyExtension(TWCurve curve) const {
    const auto masterNode = getMasterNode(*this, curve);
    const uint8_t* const extBegin = masterNode.private_key_extension;
    return PrivateKey(Data(extBegin, extBegin + PrivateKey::size));
}
89

90 1
/// Derives the private key for `coin` at `derivationPath`.
/// For extended (96-byte) curves, the key also carries the private-key extension
/// and the chain code. Otherwise only the 32-byte private key is returned.
PrivateKey HDWallet::getKey(TWCoinType coin, const DerivationPath& derivationPath) const {
    const auto curve = TWCoinTypeCurve(coin);
    const auto node = getNode(*this, curve, derivationPath);

    // Copies PrivateKey::size bytes starting at `begin`.
    const auto bytes = [](const uint8_t* begin) { return Data(begin, begin + PrivateKey::size); };

    if (getPrivateKeyType(curve) == PrivateKeyTypeExtended96) {
        return PrivateKey(bytes(node.private_key), bytes(node.private_key_extension), bytes(node.chain_code));
    }
    // default path
    return PrivateKey(bytes(node.private_key));
}
110

111 1
/// Returns the address of `coin`, derived at that coin's default derivation path.
std::string HDWallet::deriveAddress(TWCoinType coin) const {
    const auto privateKey = getKey(coin, TW::derivationPath(coin));
    return TW::deriveAddress(coin, privateKey);
}
115

116 1
/// Serializes the extended private key of account 0' under m/purpose'/coin'.
/// `version` supplies the version bytes. Returns "" for TWHDVersionNone.
std::string HDWallet::getExtendedPrivateKey(TWPurpose purpose, TWCoinType coin, TWHDVersion version) const {
    if (version == TWHDVersionNone) {
        return "";
    }

    constexpr uint32_t firstHardenedAccount = 0x80000000;
    const auto coinPath = TW::DerivationPath({DerivationPathIndex(purpose, true), DerivationPathIndex(coin, true)});
    auto node = getNode(*this, TWCoinTypeCurve(coin), coinPath);
    // The m/purpose'/coin' node's fingerprint becomes the parent fingerprint of account 0'.
    const auto parentFingerprint = fingerprint(&node, publicKeyHasher(coin));
    hdnode_private_ckd(&node, firstHardenedAccount);
    return serialize(&node, parentFingerprint, version, false, base58Hasher(coin));
}
128

129 1
/// Serializes the extended public key of account 0' under m/purpose'/coin'.
/// `version` supplies the version bytes. Returns "" for TWHDVersionNone.
std::string HDWallet::getExtendedPublicKey(TWPurpose purpose, TWCoinType coin, TWHDVersion version) const {
    if (version == TWHDVersionNone) {
        return "";
    }

    constexpr uint32_t firstHardenedAccount = 0x80000000;
    const auto coinPath = TW::DerivationPath({DerivationPathIndex(purpose, true), DerivationPathIndex(coin, true)});
    auto node = getNode(*this, TWCoinTypeCurve(coin), coinPath);
    // The m/purpose'/coin' node's fingerprint becomes the parent fingerprint of account 0'.
    const auto parentFingerprint = fingerprint(&node, publicKeyHasher(coin));
    hdnode_private_ckd(&node, firstHardenedAccount);
    // Populate the public key that serialize() writes in public mode.
    hdnode_fill_public_key(&node);
    return serialize(&node, parentFingerprint, version, true, base58Hasher(coin));
}
142

143 1
std::optional<PublicKey> HDWallet::getPublicKeyFromExtended(const std::string& extended, TWCoinType coin, const DerivationPath& path) {
144 1
    const auto curve = TW::curve(coin);
145 1
    const auto hasher = TW::base58Hasher(coin);
146

147 1
    auto node = HDNode{};
148 1
    if (!deserialize(extended, curve, hasher, &node)) {
149 1
        return {};
150
    }
151 1
    if (node.curve->params == nullptr) {
152 1
        return {};
153
    }
154 1
    hdnode_public_ckd(&node, path.change());
155 1
    hdnode_public_ckd(&node, path.address());
156 1
    hdnode_fill_public_key(&node);
157

158
    // These public key type are not applicable.  Handled above, as node.curve->params is null
159 1
    assert(curve != TWCurveED25519 && curve != TWCurveED25519Blake2bNano && curve != TWCurveED25519Extended && curve != TWCurveCurve25519);
160 1
    TWPublicKeyType keyType = TW::publicKeyType(coin);
161 1
    if (curve == TWCurveSECP256k1 && keyType == TWPublicKeyTypeSECP256k1) {
162 1
        return PublicKey(Data(node.public_key, node.public_key + 33), TWPublicKeyTypeSECP256k1);
163 1
    } else if (curve == TWCurveNIST256p1 && keyType == TWPublicKeyTypeNIST256p1) {
164 1
        return PublicKey(Data(node.public_key, node.public_key + 33), TWPublicKeyTypeNIST256p1);
165
    }
166 0
    return {};
167
}
168

169 1
std::optional<PrivateKey> HDWallet::getPrivateKeyFromExtended(const std::string& extended, TWCoinType coin, const DerivationPath& path) {
170 1
    const auto curve = TW::curve(coin);
171 1
    const auto hasher = TW::base58Hasher(coin);
172

173 1
    auto node = HDNode{};
174 1
    if (!deserialize(extended, curve, hasher, &node)) {
175 1
        return {};
176
    }
177 1
    hdnode_private_ckd(&node, path.change());
178 1
    hdnode_private_ckd(&node, path.address());
179

180 1
    return PrivateKey(Data(node.private_key, node.private_key + 32));
181
}
182

183 1
/// Chooses the private-key layout used for `curve`.
/// Only TWCurveED25519Extended (used by Cardano) uses the 96-byte extended layout.
/// Every other curve uses the default 32-byte key.
HDWallet::PrivateKeyType HDWallet::getPrivateKeyType(TWCurve curve) {
    const bool isExtendedCurve = (curve == TWCurve::TWCurveED25519Extended);
    return isExtendedCurve ? PrivateKeyTypeExtended96 : PrivateKeyTypeDefault32;
}
191

192
namespace {
193

194 1
uint32_t fingerprint(HDNode *node, Hash::Hasher hasher) {
195 1
    hdnode_fill_public_key(node);
196 1
    auto digest = hasher(node->public_key, 33);
197 1
    return ((uint32_t) digest[0] << 24) + (digest[1] << 16) + (digest[2] << 8) + digest[3];
198
}
199

200 1
std::string serialize(const HDNode *node, uint32_t fingerprint, uint32_t version, bool use_public, Hash::Hasher hasher) {
201 1
    Data node_data;
202 1
    node_data.reserve(78);
203

204 1
    encode32BE(version, node_data);
205 1
    node_data.push_back(static_cast<uint8_t>(node->depth));
206 1
    encode32BE(fingerprint, node_data);
207 1
    encode32BE(node->child_num, node_data);
208 1
    node_data.insert(node_data.end(), node->chain_code, node->chain_code + 32);
209 1
    if (use_public) {
210 1
        node_data.insert(node_data.end(), node->public_key, node->public_key + 33);
211 1
    } else {
212 1
        node_data.push_back(0);
213 1
        node_data.insert(node_data.end(), node->private_key, node->private_key + 32);
214
    }
215

216 1
    return Base58::bitcoin.encodeCheck(node_data, hasher);
217
}
218

219 1
bool deserialize(const std::string& extended, TWCurve curve, Hash::Hasher hasher, HDNode* node) {
220 1
    memset(node, 0, sizeof(HDNode));
221 1
    const char* curveNameStr = curveName(curve);
222 1
    if (curveNameStr == nullptr || ::strlen(curveNameStr) == 0) {
223 1
        return false;
224
    }
225 1
    node->curve = get_curve_by_name(curveNameStr);
226 1
    assert(node->curve != nullptr);
227

228 1
    const auto node_data = Base58::bitcoin.decodeCheck(extended, hasher);
229 1
    if (node_data.size() != 78) {
230 1
        return false;
231
    }
232

233 1
    uint32_t version = decode32BE(node_data.data());
234 1
    if (TWHDVersionIsPublic(static_cast<TWHDVersion>(version))) {
235 1
        std::copy(node_data.begin() + 45, node_data.begin() + 45 + 33, node->public_key);
236 1
    } else if (TWHDVersionIsPrivate(static_cast<TWHDVersion>(version))) {
237 1
        if (node_data[45]) { // invalid data
238 1
            return false;
239
        }
240 1
        std::copy(node_data.begin() + 46, node_data.begin() + 46 + 32, node->private_key);
241 1
    } else {
242 1
        return false; // invalid version
243
    }
244 1
    node->depth = node_data[4];
245 1
    node->child_num = decode32BE(node_data.data() + 9);
246 1
    std::copy(node_data.begin() + 13, node_data.begin() + 13 + 32, node->chain_code);
247 1
    return true;
248
}
249

250 1
HDNode getNode(const HDWallet& wallet, TWCurve curve, const DerivationPath& derivationPath) {
251 1
    const auto privateKeyType = HDWallet::getPrivateKeyType(curve);
252 1
    auto node = getMasterNode(wallet, curve);
253 1
    for (auto& index : derivationPath.indices) {
254 1
        switch (privateKeyType) {
255
            case HDWallet::PrivateKeyTypeExtended96:
256
                // special handling for extended
257 1
                hdnode_private_ckd_cardano(&node, index.derivationIndex());
258 1
                break;
259
            case HDWallet::PrivateKeyTypeDefault32:
260
            default:
261 1
                hdnode_private_ckd(&node, index.derivationIndex());
262 1
                break;
263
        }
264
    }
265 1
    return node;
266
}
267

268 1
HDNode getMasterNode(const HDWallet& wallet, TWCurve curve) {
269 1
    const auto privateKeyType = HDWallet::getPrivateKeyType(curve);
270 1
    auto node = HDNode();
271 1
    switch (privateKeyType) {
272
        case HDWallet::PrivateKeyTypeExtended96:
273
            // special handling for extended, use entropy (not seed)
274 1
            hdnode_from_seed_cardano((const uint8_t*)"", 0, wallet.entropy.data(), (int)wallet.entropy.size(), &node);
275 1
            break;
276
        case HDWallet::PrivateKeyTypeDefault32:
277
        default:
278 1
            hdnode_from_seed(wallet.seed.data(), HDWallet::seedSize, curveName(curve), &node);
279 1
            break;
280
    }
281 1
    return node;
282
}
283

284 1
const char* curveName(TWCurve curve) {
285 1
    switch (curve) {
286
    case TWCurveSECP256k1:
287 1
        return SECP256K1_NAME;
288
    case TWCurveED25519:
289 1
        return ED25519_NAME;
290
    case TWCurveED25519Blake2bNano:
291 1
        return ED25519_BLAKE2B_NANO_NAME;
292
    case TWCurveED25519Extended:
293 1
        return ED25519_CARDANO_NAME;
294
    case TWCurveNIST256p1:
295 1
        return NIST256P1_NAME;
296
    case TWCurveCurve25519:
297 1
        return CURVE25519_NAME;
298
    case TWCurveNone:
299
    default:
300 1
        return "";
301
    }
302
}
303

304
} // namespace

Read our documentation on viewing source code .

Loading