Skip to content

Commit

Permalink
check multiple derivation paths when restoring (#792)
Browse files Browse the repository at this point in the history
* check multiple derivation paths when restoring

* non null assertion
  • Loading branch information
bryzettler authored Aug 15, 2024
1 parent 97e4e9c commit bd86027
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 27 deletions.
79 changes: 59 additions & 20 deletions src/features/onboarding/import/AccountImportScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@ import React, { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { FlatList } from 'react-native'
import { KeyboardAwareScrollView } from 'react-native-keyboard-aware-scroll-view'
import { MAIN_DERIVATION_PATHS } from '@hooks/useDerivationAccounts'
import { Keypair } from '@solana/web3.js'
import { CSAccount } from '@storage/cloudStorage'
import { useAccountStorage } from '../../../storage/AccountStorageProvider'
import {
DEFAULT_DERIVATION_PATH,
createKeypair,
toSecureAccount,
} from '../../../storage/secureStorage'
import { createKeypair, toSecureAccount } from '../../../storage/secureStorage'
import { useOnboarding } from '../OnboardingProvider'
import { OnboardingNavigationProp } from '../onboardingTypes'
import PassphraseAutocomplete from './PassphraseAutocomplete'
Expand Down Expand Up @@ -149,27 +148,68 @@ const AccountImportScreen = () => {

const handleNext = useCallback(async () => {
try {
let keypair: Keypair | undefined
const filteredWords: string[] = words.flatMap((w) => (w ? [w] : []))
const { keypair } = await createKeypair({
givenMnemonic: filteredWords,
use24Words: words?.length === 24,
derivationPath:
Object.values(accounts || {}).find(
(a) => a.address === accountAddress,
)?.derivationPath || DEFAULT_DERIVATION_PATH,
})
const foundDerivation = Object.values(accounts || {}).find(
(a) => a.address === accountAddress,
)?.derivationPath

if (foundDerivation) {
keypair = (
await createKeypair({
givenMnemonic: filteredWords,
use24Words: words?.length === 24,
derivationPath: foundDerivation,
})
).keypair
}

if (restoringAccount) {
let restoredAccount: CSAccount | undefined
if (!accounts || !accountAddress) {
await showOKAlert({
title: t('restoreAccount.errorAlert.title'),
message: t('restoreAccount.errorAlert.message'),
})
return
}
const restoredAccount = Object.values(accounts).find(
(a) => a.solanaAddress === keypair.publicKey.toBase58(),
)
if (!restoredAccount || accountAddress !== restoredAccount.address) {

if (keypair) {
restoredAccount = Object.values(accounts).find(
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
(a) => a.solanaAddress === keypair!.publicKey.toBase58(),
)
} else {
const keypairs = await Promise.all(
MAIN_DERIVATION_PATHS.map((dpath) =>
createKeypair({
givenMnemonic: filteredWords,
use24Words: words?.length === 24,
derivationPath: dpath,
}),
),
)

restoredAccount = Object.values(accounts).find((a) =>
keypairs.some(
(k) => a.solanaAddress === k.keypair.publicKey.toBase58(),
),
)

if (restoredAccount) {
keypair = keypairs.find(
(k) =>
restoredAccount?.solanaAddress ===
k.keypair.publicKey.toBase58(),
)?.keypair
}
}

if (
!keypair ||
!restoredAccount ||
accountAddress !== restoredAccount.address
) {
await showOKAlert({
title: t('restoreAccount.errorAlert.title'),
message: t('restoreAccount.errorAlert.message'),
Expand All @@ -183,10 +223,9 @@ const AccountImportScreen = () => {
})
reset()
navigation.popToTop()
} else {
setOnboardingData((prev) => ({ ...prev, words: filteredWords }))
navigation.navigate('ImportSubAccounts')
}
setOnboardingData((prev) => ({ ...prev, words: filteredWords }))
navigation.navigate('ImportSubAccounts')
} catch (error) {
await showOKAlert({
title: t('accountImport.alert.title'),
Expand Down
17 changes: 10 additions & 7 deletions src/hooks/useDerivationAccounts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ export type ResolvedPath = {
}

export const HELIUM_DERIVATION = 'Helium L1'
export const MAIN_DERIVATION_PATHS = [
HELIUM_DERIVATION,
heliumDerivation(-1),
solanaDerivation(-1, undefined),
]
export const useDerivationAccounts = ({ mnemonic }: { mnemonic?: string }) => {
const { connection } = useSolana()
const [resolvedGroups, setResolvedGroups] = useState<ResolvedPath[][]>([])
Expand All @@ -75,12 +80,6 @@ export const useDerivationAccounts = ({ mnemonic }: { mnemonic?: string }) => {
[resolvedGroups],
)

const mains = [
HELIUM_DERIVATION,
heliumDerivation(-1),
solanaDerivation(-1, undefined),
]

const solanaWithChange = (start: number, end: number) =>
new Array(end - start).fill(0).map((_, i) => solanaDerivation(i + start, 0))

Expand All @@ -90,7 +89,11 @@ export const useDerivationAccounts = ({ mnemonic }: { mnemonic?: string }) => {
.map((_, i) => solanaDerivation(i + start, undefined))

const [groups, setGroups] = useState([
[...mains, ...solanaWithChange(0, 10), ...solanaWithoutChange(0, 10)],
[
...MAIN_DERIVATION_PATHS,
...solanaWithChange(0, 10),
...solanaWithoutChange(0, 10),
],
])

// When mnemonic changes, reset resolved groups
Expand Down

0 comments on commit bd86027

Please sign in to comment.