From 3a4e9960e42030dd4721bd7e652e9b56e0bf19a5 Mon Sep 17 00:00:00 2001 From: Zeyun Date: Sun, 7 Jul 2024 17:44:42 -0700 Subject: [PATCH 1/4] update setup.cfg --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index 3585c61..bd5e3ef 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,7 +10,7 @@ author = Zeyun Lu, Nicholas Mancuso author_email = zeyunlu@usc.edu, Nicholas.Mancuso@med.usc.edu license = MIT license_files = LICENSE.txt -long_description = file: README.rst +long_description = file: README.md long_description_content_type = text/x-rst; charset=UTF-8 url = https://github.com/mancusolab/sushie # Add here related links, for example: From 219e19ce4484507eab86a4d3a95012d304e06ffb Mon Sep 17 00:00:00 2001 From: Zeyun Date: Mon, 8 Jul 2024 01:29:14 -0700 Subject: [PATCH 2/4] add remove ambiguous snps option and fix several small bugs. --- README.md | 8 +- data/make_example.py | 13 ++- data/prior_weights | 109 +++++++++++++++++++++ docs/manual.rst | 26 ++++- sushie/cli.py | 227 ++++++++++++++++++++++++++++++++++--------- sushie/infer.py | 29 ++++-- sushie/io.py | 18 ++-- 7 files changed, 352 insertions(+), 78 deletions(-) create mode 100644 data/prior_weights diff --git a/README.md b/README.md index 6a75fd5..58046fc 100644 --- a/README.md +++ b/README.md @@ -108,13 +108,13 @@ You can play it with your own ideas! | 0.13 | Add `--keep` command to enable user to specify a file that contains the subjects ID SuShiE will perform on. Add `--ancestry_index` command to enable user to specify a file that contains the ancestry index for fine-mapping. With this, user can input single phenotype, genotype, and covariate file that contains all the subjects across ancestries. Implement padding to increase inference time. Record elbo at each iteration and can access it in the `infer.SuShiEResult` object. The alphas table now outputs the average purity and KL divergence for each `L`. Change `--kl_threshold` to `--divergence`. Add `--maf` command to remove SNPs that less than minor allele frequency threshold within each ancestry. Add `--max_select` command to randomly select maximum number of SNPs to compute purity to avoid unnecessary memory spending. Add a QC function to remove duplicated SNPs. | | 0.14 | Remove KL-Divergence pruning. Enhance command line appearance and improve the output files contents. Fix small bugs on multivariate KL. | | 0.15 | Fix several typos; add a sanity check on reading vcf genotype data by assigning gt_types==Unknown as NA; Add preprint information. | +| 0.16 | Add option to remove ambiguous SNPs; fix several bugs and enhance codes quality. | ## Support -Please report any bugs or feature requests in the [Issue -Tracker](https://github.com/mancusolab/sushie/issues). If users have any -questions or comments, please contact Zeyun Lu () and -Nicholas Mancuso (). +For any questions, comments, bug reporting, and feature requests, please contact Zeyun Lu () and +Nicholas Mancuso (), and open a new thread in the [Issue +Tracker](https://github.com/mancusolab/sushie/issues). ## Other Software diff --git a/data/make_example.py b/data/make_example.py index 61c29ef..10ac159 100644 --- a/data/make_example.py +++ b/data/make_example.py @@ -10,6 +10,8 @@ # flip allele if necessary def _allele_check(baseA0, baseA1, compareA0, compareA1): + # no snps that have more than 2 alleles + # e.g., G and T for EUR and G and A for AFR correct = jnp.array( ((baseA0 == compareA0) * 1) * ((baseA1 == compareA1) * 1), dtype=int ) @@ -18,9 +20,8 @@ def _allele_check(baseA0, baseA1, compareA0, compareA1): ) correct_idx = jnp.where(correct == 1)[0] flipped_idx = jnp.where(flipped == 1)[0] - wrong_idx = jnp.where((correct + flipped) == 0)[0] - return correct_idx, flipped_idx, wrong_idx + return correct_idx, flipped_idx # read plink1.9 triplet @@ -51,14 +52,13 @@ def _allele_check(baseA0, baseA1, compareA0, compareA1): if n_pop > 1: for idx in range(1, n_pop): # keep track of miss match alleles - correct_idx, tmp_flip_idx, tmp_wrong_idx = _allele_check( + correct_idx, tmp_flip_idx = _allele_check( snps["a0_0"].values, snps["a1_0"].values, snps[f"a0_{idx}"].values, snps[f"a1_{idx}"].values, ) flip_idx.append(tmp_flip_idx) - snps = snps.drop(index=tmp_wrong_idx) snps = snps.drop(columns=[f"a0_{idx}", f"a1_{idx}"]) snps = snps.rename(columns={"a0_0": "a0", "a1_0": "a1"}) @@ -130,3 +130,8 @@ def _allele_check(baseA0, baseA1, compareA0, compareA1): pd.concat(fam_index)[["iid"]].iloc[sel_pt, :].to_csv( "./keep.subject", index=False, header=None, sep="\t" ) + +rng_key, unif_key = random.split(rng_key, 2) +unif = random.uniform(unif_key, shape=(snps.shape[0],), minval=5, maxval=200) +snps["unif"] = unif +snps[["snp", "unif"]].to_csv("./prior_weights", index=False, header=None, sep="\t") diff --git a/data/prior_weights b/data/prior_weights new file mode 100644 index 0000000..ba0d021 --- /dev/null +++ b/data/prior_weights @@ -0,0 +1,109 @@ +rs2294190 82.629944 +rs4653029 153.17215 +rs2075931 123.91783 +rs1883109 127.030655 +rs2038008 158.2589 +rs17385253 36.513393 +rs2142534 37.721756 +rs12098170 194.61343 +rs4653032 40.086555 +rs4653033 132.69426 +rs10799058 189.72482 +rs2075934 139.20802 +rs11578724 121.33664 +rs12134720 43.281784 +rs10489377 13.916 +rs4653036 11.430512 +rs10799059 176.69559 +rs2075933 61.103603 +rs12035126 64.29181 +rs4653037 24.734858 +rs1016091 189.90926 +rs6425895 107.337944 +rs7544781 167.4216 +rs4653042 181.94672 +rs10799060 135.90553 +rs12409604 191.37685 +rs750882 199.56975 +rs798076 72.80513 +rs798075 126.811356 +rs10914956 76.430916 +rs11586990 170.84335 +rs10914958 34.19725 +rs41416648 179.50003 +rs798066 77.906 +rs3856131 23.705788 +rs3905173 87.37802 +rs12047023 129.29813 +rs3845481 106.01135 +rs2359106 74.0428 +rs4652841 127.19247 +rs12027125 19.288433 +rs12566683 165.95 +rs6694043 137.73972 +rs4652846 187.0494 +rs12038357 58.336098 +rs1539374 83.38 +rs10493058 143.16483 +rs10914967 145.76859 +rs10737381 67.32597 +rs1418483 191.92468 +rs7539462 152.5103 +rs12136079 79.86711 +rs10799067 100.771484 +rs1342467 92.286095 +rs947671 21.026522 +rs6665029 73.00425 +rs6677231 25.26735 +rs6671294 152.75058 +rs7530276 146.42778 +rs12566354 122.516945 +rs10914983 193.34686 +rs12074036 192.44273 +rs12033008 70.78067 +rs12039962 183.77693 +rs1342476 61.950565 +rs6679913 131.7318 +rs6703587 181.58058 +rs10914986 181.33696 +rs728448 52.374058 +rs728447 47.42535 +rs10799073 108.81207 +rs12090691 76.55221 +rs1886340 72.35997 +rs4653061 139.58275 +rs4653062 120.76586 +rs10914991 122.037544 +rs12117895 170.4533 +rs1342475 194.37923 +rs2147384 111.788536 +rs7544557 27.45278 +rs6695510 28.803083 +rs716289 108.95569 +rs10753314 149.97383 +rs10799074 184.9501 +rs2359631 87.64986 +rs7547341 5.1670213 +rs12059222 12.2328415 +rs2093880 76.70654 +rs6425904 107.91474 +rs7548924 115.589226 +rs10914996 188.51263 +rs12132123 155.16908 +rs868501 124.81361 +rs612353 92.61688 +rs12036748 190.79884 +rs4652852 5.195567 +rs491000 65.55253 +rs506361 85.42702 +rs6425908 187.92491 +rs9787098 43.03403 +rs509788 22.220636 +rs12039039 35.800373 +rs12041637 172.85303 +rs9787040 140.03925 +rs10915010 53.220226 +rs4491030 27.456919 +rs10493062 127.62521 +rs912537 66.919495 +rs927656 145.42041 diff --git a/docs/manual.rst b/docs/manual.rst index 328ede1..424198d 100644 --- a/docs/manual.rst +++ b/docs/manual.rst @@ -47,7 +47,8 @@ Although we highly recommend users to perform high-quality QC on their own genot #. Only keep SNPs that are available in all the ancestries. #. Adjust genotype data across ancestries based on the same reference alleles. Drop non-biallelic SNPs. #. Remove SNPs that have minor allele frequency (MAF) less than 1% within each ancestry (users can change 1% with ``--maf``). -#. For single ancestry SuSiE, users have the option to perform rank inverse normalization transformation on the phenotype data. +# Users also have an option to remove ambiguous SNPs (i.e., A/T, T/A, C/G, or GC) by specifying ``--remove-ambiguous`` (Default is NOT to remove them). +#. For single ancestry SuSiE or Mega-SuSiE, users have the option to perform rank inverse normalization transformation on the phenotype data. See :func:`sushie.cli.process_raw` for these QCs' source codes. @@ -260,6 +261,16 @@ Users can use ``--keep`` command to specify a file that contains the subject IDs cd ./data/ sushie finemap --pheno EUR.pheno AFR.pheno --vcf vcf/EUR.vcf vcf/AFR.vcf --keep keep.subject --output ./test_result +14. I want to assign the prior weights for each SNP +-------------------------------------------------- + +Users can use ``--pi`` command to specify a tsv file that contains the SNP ID and their prior weights. The weights will be normalized to sum to 1 before inference. + +.. code:: bash + + cd ./data/ + sushie finemap --pheno EUR.pheno AFR.pheno --vcf vcf/EUR.vcf vcf/AFR.vcf --pi prior_weights --output ./test_result + .. _Param: Parameters @@ -314,10 +325,10 @@ Parameters - ``--L 5`` - Integer number of shared effects pre-specified. Larger number may cause slow inference. * - ``--pi`` - - Float - - 1/p - - ``--pi 0.1`` - - Prior probability for each SNP to be causal (:math:`\pi` in :ref:`Model`). Default is ``1/p`` where ``p`` is the number of SNPs in the region. It is the fixed across all ancestries. + - str + - "uniform" + - ``--pi ./prior_weights`` + - Prior probability for each SNP to be causal (:math:`\pi` in :ref:`Model`). Default is uniform (i.e., ``1/p`` where ``p`` is the number of SNPs in the region. It is the fixed across all ancestries. Alternatively, users can specify the file path that contains the prior weights for each SNP. The weights have to be positive value. The weights will be normalized to sum to 1 before inference. The file has to be a tsv file that contains two columns where the first column is the SNP ID and the second column is the prior weights. Additional columns will be ignored. For SNPs do not have prior weights in the file, it will be assigned the average value of the rest. It can be a compressed file (e.g., tsv.gz). No headers. * - ``--resid-var`` - Float - 1e-3 @@ -393,6 +404,11 @@ Parameters - False - ``--no-reorder`` - Indicator to re-order single effects based on Frobenius norm of alpha-weighted posterior mean square. Default is False (to re-order). Specify --no-reorder will store 'True' value. + * - ``--remove-ambiguous`` + - Boolean + - False + - ``--remove-ambiguous`` + - Indicator to remove ambiguous SNPs (i.e., A/T, T/A, C/G, or G/C) from the genotypes. Recommend to remove these SNPs if each ancestry data is from different studies. Default is False (do not remove). Specify --remove-ambiguous will store 'True' value. * - ``--meta`` - Boolean - False diff --git a/sushie/cli.py b/sushie/cli.py index 1fcc11e..4027854 100755 --- a/sushie/cli.py +++ b/sushie/cli.py @@ -63,11 +63,11 @@ def _keep_file_subjects( old_pheno_num = pheno.shape[0] bed = bed[fam.iid.isin(keep_subject).values, :] - fam = fam.loc[fam.iid.isin(keep_subject)] - pheno = pheno.loc[pheno.iid.isin(keep_subject)] + fam = fam.loc[fam.iid.isin(keep_subject)].reset_index(drop=True) + pheno = pheno.loc[pheno.iid.isin(keep_subject)].reset_index(drop=True) if covar is not None: - covar = covar.loc[covar.iid.isin(keep_subject)] + covar = covar.loc[covar.iid.isin(keep_subject)].reset_index(drop=True) del_fam_num = old_fam_num - fam.shape[0] del_pheno_num = old_pheno_num - pheno.shape[0] @@ -95,13 +95,19 @@ def _drop_na_subjects(rawData: io.RawData) -> Tuple[io.RawData, int]: ) del_idx = jnp.logical_or(del_idx, covar_del) - (drop_idx,) = jnp.where(del_idx) - fam = fam.drop(drop_idx) - pheno = pheno.drop(drop_idx) - bed = jnp.delete(bed, drop_idx, 0) + (drop_idx_name,) = jnp.where(del_idx) + + # ambiguous_idx is the positional index, but drop use the index label, so we need to convert it + fam_drop_idx = fam.index[drop_idx_name] + pheno_drop_idx = pheno.index[drop_idx_name] + + fam = fam.drop(fam_drop_idx).reset_index(drop=True) + pheno = pheno.drop(pheno_drop_idx).reset_index(drop=True) + bed = jnp.delete(bed, drop_idx_name, 0) if covar is not None: - covar = covar.drop(drop_idx) + covar_drop_idx = pheno.index[drop_idx_name] + covar = covar.drop(covar_drop_idx).reset_index(drop=True) rawData = rawData._replace( fam=fam, @@ -110,6 +116,26 @@ def _drop_na_subjects(rawData: io.RawData) -> Tuple[io.RawData, int]: covar=covar, ) + return rawData, len(drop_idx_name) + + +def _remove_ambiguous_geno(rawData: io.RawData) -> Tuple[io.RawData, int]: + bim, _, bed, _, _ = rawData + + ambiguous_snps = ["AT", "TA", "CG", "GC"] + ambiguous_idx = jnp.where((bim.A0.values + bim.A1.values).isin(ambiguous_snps))[0] + + # ambiguous_idx is the positional index, but drop use the index label, so we need to convert it + drop_idx = bim.index[ambiguous_idx] + bim = bim.drop(drop_idx).reset_index(drop=True) + + bed = jnp.delete(bed, ambiguous_idx, 1) + + rawData = rawData._replace( + bim=bim, + bed=bed, + ) + return rawData, len(drop_idx) @@ -118,7 +144,10 @@ def _remove_dup_geno(rawData: io.RawData) -> Tuple[io.RawData, int]: (dup_idx,) = jnp.where(bim.snp.duplicated().values) - bim = bim.drop(dup_idx) + # dup_idx is the positional index, but drop use the index label, so we need to convert it + dup_idx_name = bim.index[dup_idx] + + bim = bim.drop(dup_idx_name).reset_index(drop=True) bed = jnp.delete(bed, dup_idx, 1) rawData = rawData._replace( @@ -135,7 +164,10 @@ def _impute_geno(rawData: io.RawData) -> Tuple[io.RawData, int, int]: # if we observe SNPs have nan value for all participants (although not likely), drop them (del_idx,) = jnp.where(jnp.all(jnp.isnan(bed), axis=0)) - bim = bim.drop(del_idx) + # del_idx is the positional index, but drop use the index label, so we need to convert it + del_idx_name = bim.index[del_idx] + + bim = bim.drop(del_idx_name).reset_index(drop=True) bed = jnp.delete(bed, del_idx, 1) # if we observe SNPs that partially have nan value, impute them with column mean @@ -233,7 +265,10 @@ def _allele_check( def _prepare_cv( - geno: List[jnp.ndarray], pheno: List[jnp.ndarray], cv_num: int, seed: int + geno: List[jnp.ndarray], + pheno: List[jnp.ndarray], + cv_num: int, + seed: int, ) -> List[io.CVData]: rng_key = random.PRNGKey(seed) n_pop = len(geno) @@ -284,7 +319,7 @@ def _prepare_cv( return cv_data -def _run_cv(args, cv_data) -> List[List[jnp.ndarray]]: +def _run_cv(args, cv_data, pi) -> List[List[jnp.ndarray]]: n_pop = len(cv_data[0].train_geno) # create a list to store future estimated y value est_y = [jnp.array([])] * n_pop @@ -298,7 +333,7 @@ def _run_cv(args, cv_data) -> List[List[jnp.ndarray]]: no_scale=args.no_scale, no_regress=args.no_regress, no_update=args.no_update, - pi=args.pi, + pi=pi, resid_var=args.resid_var, effect_var=args.effect_var, rho=args.rho, @@ -330,7 +365,7 @@ def _run_cv(args, cv_data) -> List[List[jnp.ndarray]]: def parameter_check( args, -) -> Tuple[int, pd.DataFrame, List[str], List[str], Callable]: +) -> Tuple[int, pd.DataFrame, List[str], pd.DataFrame, List[str], Callable]: """The function to process raw phenotype, genotype, covariates data across ancestries. Args: @@ -342,6 +377,7 @@ def parameter_check( #. an integer to indicate how many ancestries, #. a DataFrame that contains ancestry index (can be none), #. a list that contains subject ID that fine-mapping performs on. + #. a DataFrame that contains prior probability for each SNP to be causal. #. a list of genotype data paths (:py:obj:`List[str]`), #. genotype read-in function (:py:obj:`Callable`). @@ -484,6 +520,32 @@ def parameter_check( ) keep_subject = df_keep[0].values.tolist() + if args.pi != "uniform": + log.logger.info( + "Detect file that contains prior weights for each SNP to be causal." + ) + pi = pd.read_csv(args.pi, header=None, sep="\t") + if pi.shape[0] == 0: + raise ValueError( + "No prior weights are listed in the prior file. Check the source." + ) + + if pi.shape[1] < 2: + raise ValueError( + "The prior file has less than 2 columns. It has to be at least two columns." + + " The first column is the SNP ID, and the second column the prior probability." + ) + + if pi.shape[1] > 2: + log.logger.debug( + "The prior file has more than 2 columns. Will only use the first two columns." + ) + + pi = pi.iloc[:, 0:2] + pi.columns = ["snp", "pi"] + else: + pi = pd.DataFrame() + if args.seed <= 0: raise ValueError( "The seed specified for randomization is invalid. Choose a positive integer." @@ -507,12 +569,14 @@ def parameter_check( "The number of ancestry is 1, but --meta or --mega is specified. Will skip meta or mega SuSiE." ) - return n_pop, ancestry_index, keep_subject, geno_path, geno_func + return n_pop, ancestry_index, keep_subject, pi, geno_path, geno_func def process_raw( rawData: List[io.RawData], keep_subject: List[str], + pi: pd.DataFrame, + remove_ambiguous: bool, maf: float, rint: bool, no_regress: bool, @@ -531,6 +595,8 @@ def process_raw( Args: rawData: Raw data for phenotypes, genotypes, covariates across ancestries. keep_subject: The DataFrame that contains subject ID that fine-mapping performs on. + pi: The DataFrame that contains prior weights for each SNP to be causal. + remove_ambiguous: The indicator whether to remove ambiguous SNPs. maf: The minor allele frequency threshold to filter the genotypes. rint: The indicator whether to perform rank inverse normalization on each phenotype data. no_regress: The indicator whether to regress genotypes on covariates. @@ -605,15 +671,20 @@ def process_raw( + " Check the source." ) - old_snp_num = rawData[idx].bim.shape[0] - # remove duplicates SNPs based on rsid even though we suggest users to do some QC on this - rawData[idx], del_num = _remove_dup_geno(rawData[idx]) + # remove ambiguous SNPs (i.e., A/T, T/A, C/G, G/C pairs) in genotype data + if remove_ambiguous: + old_snp_num = rawData[idx].bim.shape[0] + rawData[idx], del_num = _remove_ambiguous_geno(rawData[idx]) - if del_num != 0: - log.logger.debug( - f"Ancestry {idx + 1}: Drop {del_num} out of {old_snp_num} SNPs because of duplicates in the rs ID" - + " in genotype data." - ) + if del_num == old_snp_num: + raise ValueError( + f"Ancestry {idx + 1}: All SNPs are ambiguous in genotype data. Check the source." + ) + + if del_num != 0: + log.logger.debug( + f"Ancestry {idx + 1}: Drop {del_num} ambiguous SNPs out of {old_snp_num} in genotype data." + ) old_snp_num = rawData[idx].bim.shape[0] # impute genotype data even though we suggest users to impute the genotypes beforehand @@ -650,9 +721,14 @@ def process_raw( f"Ancestry {idx + 1}: Drop {del_num} out of {old_snp_num} SNPs because of maf threshold at {maf}." ) - if rawData[idx].bim.shape[0] == 0: - raise ValueError( - f"Ancestry {idx + 1}: no SNPs left after QC. Check the source." + old_snp_num = rawData[idx].bim.shape[0] + # remove duplicates SNPs based on rsid even though we suggest users to do some QC on this + rawData[idx], del_num = _remove_dup_geno(rawData[idx]) + + if del_num != 0: + log.logger.debug( + f"Ancestry {idx + 1}: Drop {del_num} out of {old_snp_num} SNPs because of duplicates in the rs ID" + + " in genotype data." ) # reset index and add index column to all dataset for future inter-ancestry or inter-dataset processing @@ -693,11 +769,35 @@ def process_raw( else: snps = rawData[0].bim + # remove non-biallelic SNPs across ancestries + if n_pop > 1: + for idx in range(1, n_pop): + _, _, remove_idx = _allele_check( + snps["a0_1"].values, + snps["a1_1"].values, + snps[f"a0_{idx + 1}"].values, + snps[f"a1_{idx + 1}"].values, + ) + + if len(remove_idx) != 0: + snps = snps.drop(index=remove_idx).reset_index(drop=True) + log.logger.debug( + f"Ancestry{idx + 1} has {len(remove_idx)} alleles that" + + "couldn't match to ancestry 1 and couldn't be flipped. Will remove these SNPs." + ) + + if snps.shape[0] == 0: + raise ValueError( + f"Ancestry {idx + 1} has none of correct or flippable SNPs from ancestry 1. Check the source.", + ) + + snps = snps.reset_index(drop=True) + # find flipped reference alleles across ancestries flip_idx = [] if n_pop > 1: for idx in range(1, n_pop): - correct_idx, tmp_flip_idx, wrong_idx = _allele_check( + correct_idx, tmp_flip_idx, _ = _allele_check( snps["a0_1"].values, snps["a1_1"].values, snps[f"a0_{idx + 1}"].values, @@ -712,25 +812,31 @@ def process_raw( # save the index for future swapping flip_idx.append(tmp_flip_idx) - if len(wrong_idx) != 0: - snps = snps.drop(index=wrong_idx) - log.logger.debug( - f"Ancestry{idx + 1} has {len(wrong_idx)} alleles that couldn't be flipped. Will remove these SNPs." - ) - - if snps.shape[0] == 0: - raise ValueError( - f"Ancestry {idx + 1} has none of correct or flippable SNPs from ancestry 1. Check the source.", - ) # drop unused columns snps = snps.drop( columns=[f"a0_{idx + 1}", f"a1_{idx + 1}", f"pos_{idx + 1}"] ) - # rename columns for better indexing in the future + + # rename columns for better indexing in the future snps = snps.reset_index().rename( columns={"index": "SNPIndex", "a0_1": "a0", "a1_1": "a1", "pos_1": "pos"} ) + if pi.shape[0] != 0: + # append prior weights to the snps + snps = pd.merge(snps, pi, how="left", on="snp") + nan_count = snps["pi"].isna().sum() + if nan_count > snps.shape[0] * 0.25: + log.logger.warning( + "More than 25% of SNPs have missing prior weights. Will replace them with the mean value of the rest." + ) + # if the column pi has nan value, replace it with the mean value of the rest of the column + snps["pi"] = snps["pi"].fillna(snps["pi"].mean()) + pi = jnp.array(snps["pi"].values) + else: + snps["pi"] = jnp.ones(snps.shape[0]) / float(snps.shape[0]) + pi = None + geno = [] pheno = [] covar = [] @@ -749,6 +855,8 @@ def process_raw( tmp_geno = tmp_geno[common_ind_id, :][:, common_snp_id] # flip genotypes for bed files starting second ancestry + # flip index is the positional index based on snps data frame, so we have to subset genotype + # data based on the common snps (i.e., snps data frame). if idx > 0 and len(flip_idx[idx - 1]) != 0: tmp_geno[:, flip_idx[idx - 1]] = 2 - tmp_geno[:, flip_idx[idx - 1]] @@ -776,7 +884,7 @@ def process_raw( else: data_covar = covar - regular_data = io.CleanData(geno=geno, pheno=pheno, covar=data_covar) + regular_data = io.CleanData(geno=geno, pheno=pheno, covar=data_covar, pi=pi) name_ancestry = "ancestry" if n_pop == 1 else "ancestries" @@ -816,6 +924,7 @@ def process_raw( geno=[mega_geno], pheno=[mega_pheno], covar=None, + pi=pi, ) return snps, regular_data, mega_data, cv_data @@ -895,7 +1004,7 @@ def sushie_wrapper( no_scale=args.no_scale, no_regress=args.no_regress, no_update=args.no_update, - pi=args.pi, + pi=data.pi, resid_var=resid_var, effect_var=effect_var, rho=None, @@ -931,7 +1040,7 @@ def sushie_wrapper( no_scale=args.no_scale, no_regress=args.no_regress, no_update=args.no_update, - pi=args.pi, + pi=data.pi, resid_var=resid_var, effect_var=effect_var, rho=rho, @@ -985,8 +1094,8 @@ def sushie_wrapper( log.logger.info( f"Start {args.cv_num}-fold cross validation as --cv is specified " ) - cv_res = _run_cv(args, cv_data) - sample_size = [idx.shape[0] for idx in data.geno] + cv_res = _run_cv(args, cv_data, data.pi) + sample_size = jnp.squeeze(tmp_result.sample_size) io.output_cv(cv_res, sample_size, output, args.trait, args.compress) return None @@ -1007,7 +1116,7 @@ def run_finemap(args): config.update("jax_platform_name", args.platform) - n_pop, ancestry_index, keep_subject, geno_path, geno_func = parameter_check( + n_pop, ancestry_index, keep_subject, pi, geno_path, geno_func = parameter_check( args ) @@ -1023,6 +1132,8 @@ def run_finemap(args): snps, regular_data, mega_data, cv_data = process_raw( rawData, keep_subject, + pi, + args.remove_ambiguous, args.maf, args.rint, args.no_regress, @@ -1184,19 +1295,27 @@ def build_finemap_parser(subp): type=int, help=( "Integer number of shared effects pre-specified.", - " Default is 5. Larger number may cause slow inference.", + " Default is 10. Larger number may cause slow inference.", ), ) # fine-map prior options finemap.add_argument( "--pi", - default=None, - type=float, + default="uniform", + type=str, help=( "Prior probability for each SNP to be causal.", - " Default is 1/p where p is the number of SNPs in the region.", + " Default is uniform (i.e., 1/p where p is the number of SNPs in the region.", " It is the fixed across all ancestries.", + " Alternatively, users can specify the file path that contains the prior weights for each SNP.", + " The weights have to be positive value.", + " The weights will be normalized to sum to 1 before inference.", + " The file has to be a tsv file that contains two columns where the", + " first column is the SNP ID and the second column is the prior weights.", + " Additional columns will be ignored.", + " For SNPs do not have prior weights in the file, it will be assigned the average value of the rest.", + " It can be a compressed file (e.g., tsv.gz). No headers.", ), ) @@ -1377,6 +1496,18 @@ def build_finemap_parser(subp): ), ) + finemap.add_argument( + "--remove-ambiguous", + default=False, + action="store_true", + help=( + "Indicator to remove ambiguous SNPs (i.e., A/T, T/A, C/G, or G/C pairs) from the genotypes.", + " Recommend to remove these SNPs if each ancestry data is from different studies.", + " Default is False (not to remove).", + " Specify --remove-ambiguous will store 'True' value.", + ), + ) + # I/O option finemap.add_argument( "--meta", diff --git a/sushie/infer.py b/sushie/infer.py index 6fe3aca..3d9ae19 100644 --- a/sushie/infer.py +++ b/sushie/infer.py @@ -160,7 +160,7 @@ def infer_sushie( no_scale: bool = False, no_regress: bool = False, no_update: bool = False, - pi: Array = None, + pi: ArrayLike = None, resid_var: utils.ListFloatOrNone = None, effect_var: utils.ListFloatOrNone = None, rho: utils.ListFloatOrNone = None, @@ -262,10 +262,22 @@ def infer_sushie( "The maximum selected number of SNPs is too small thus may miss true positives. Choose a positive integer." ) - if pi is not None and (pi >= 1 or pi <= 0): - raise ValueError( - f"Pi prior ({pi}) is not a probability (0-1). Specify a valid pi prior." - ) + if pi is not None: + if not (pi > 0).all(): + raise ValueError( + "Prior probability/weights contain negative value. Specify a valid pi prior." + ) + + if pi.shape[0] != Xs[0].shape[1]: + raise ValueError( + f"Prior probability/weights ({pi.shape[0]}) does not match the number of SNPs ({Xs[0].shape[1]})." + ) + + if jnp.sum(pi) > 1: + log.logger.debug( + "Prior probability/weights sum to more than 1. Will normalize to sum to 1." + ) + pi = float(pi / jnp.sum(pi)) # first regress out covariates if there are any, then scale the genotype and phenotype if covar is not None: @@ -457,11 +469,11 @@ def infer_sushie( ) elbo_last = elbo_tracker[o_iter] elbo_tracker = jnp.append(elbo_tracker, elbo_cur) - elbo_increase = elbo_cur < elbo_last and ( - not jnp.isclose(elbo_cur, elbo_last, atol=1e-8) + elbo_increase = not ( + elbo_cur < elbo_last and (not jnp.isclose(elbo_cur, elbo_last, atol=1e-8)) ) - if elbo_increase or jnp.isnan(elbo_cur): + if (not elbo_increase) or jnp.isnan(elbo_cur): log.logger.warning( f"Optimization concludes after {o_iter + 1} iterations." + f" ELBO decreases. Final ELBO score: {elbo_cur}. Return last iteration's results." @@ -471,7 +483,6 @@ def infer_sushie( ) priors = prev_priors posteriors = prev_posteriors - elbo_increase = False break decimal_digit = len(str(min_tol)) - str(min_tol).find(".") - 1 diff --git a/sushie/io.py b/sushie/io.py index b249f25..5f8ac46 100644 --- a/sushie/io.py +++ b/sushie/io.py @@ -34,7 +34,7 @@ class CVData(NamedTuple): - """Define the raw data object for the future inference. + """Define the cross validation data object. Attributes: train_geno: genotype data for training SuShiE weights. @@ -51,22 +51,24 @@ class CVData(NamedTuple): class CleanData(NamedTuple): - """Define the raw data object for the future inference. + """Define the clean data object ready for the future inference. Attributes: geno: actual genotype data. pheno: phenotype data. covar: covariate needed to be adjusted in the inference. + pi: prior weights for each SNP to be causal. """ geno: List[Array] pheno: List[Array] covar: utils.ListArrayOrNone + pi: utils.ListArrayOrNone class RawData(NamedTuple): - """Define the raw data object for the future inference. + """Define the raw data object for the future data cleaning. Attributes: bim: SNP information data. @@ -142,12 +144,12 @@ def read_data( tmp_covar = covar if index_file: tmp_pt = ancestry_index.loc[ancestry_index[1] == (idx + 1)][0] - tmp_fam = fam.loc[fam.iid.isin(tmp_pt)] + tmp_fam = fam.loc[fam.iid.isin(tmp_pt)].reset_index(drop=True) tmp_bed = bed[fam.iid.isin(tmp_pt).values, :] - tmp_pheno = pheno.loc[pheno.iid.isin(tmp_pt)] + tmp_pheno = pheno.loc[pheno.iid.isin(tmp_pt)].reset_index(drop=True) if covar_paths is not None: - tmp_covar = covar.loc[covar.iid.isin(tmp_pt)] + tmp_covar = covar.loc[covar.iid.isin(tmp_pt)].reset_index(drop=True) else: tmp_covar = None @@ -218,8 +220,8 @@ def read_vcf(path: str) -> Tuple[pd.DataFrame, pd.DataFrame, Array]: for var in vcf: # var.ALT is a list of alternative allele bim_list.append([var.CHROM, var.ID, var.POS, var.ALT[0], var.REF]) - var.gt_types = jnp.where(var.gt_types == 3, jnp.nan, var.gt_types) - tmp_bed = 2 - var.gt_types + tmp_gt_types = jnp.where(var.gt_types == 3, jnp.nan, var.gt_types) + tmp_bed = 2 - tmp_gt_types bed_list.append(tmp_bed) bim = pd.DataFrame(bim_list, columns=["chrom", "snp", "pos", "a0", "a1"]) From b05684b32ed77eb13c1839e1dd3780e000da0e10 Mon Sep 17 00:00:00 2001 From: Zeyun Date: Mon, 8 Jul 2024 01:34:56 -0700 Subject: [PATCH 3/4] small fix --- sushie/cli.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sushie/cli.py b/sushie/cli.py index 4027854..9fd3a5f 100755 --- a/sushie/cli.py +++ b/sushie/cli.py @@ -672,7 +672,8 @@ def process_raw( ) # remove ambiguous SNPs (i.e., A/T, T/A, C/G, G/C pairs) in genotype data - if remove_ambiguous: + # we just need to remove it for the first ancestry, later ancestries will be merged into first ancestry + if remove_ambiguous and idx == 0: old_snp_num = rawData[idx].bim.shape[0] rawData[idx], del_num = _remove_ambiguous_geno(rawData[idx]) From b7955f030778a29f79c092c1dbd1677c4edf1ad5 Mon Sep 17 00:00:00 2001 From: Zeyun Date: Mon, 8 Jul 2024 02:07:35 -0700 Subject: [PATCH 4/4] enhance codes --- sushie/infer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sushie/infer.py b/sushie/infer.py index 3d9ae19..349074b 100644 --- a/sushie/infer.py +++ b/sushie/infer.py @@ -469,11 +469,11 @@ def infer_sushie( ) elbo_last = elbo_tracker[o_iter] elbo_tracker = jnp.append(elbo_tracker, elbo_cur) - elbo_increase = not ( - elbo_cur < elbo_last and (not jnp.isclose(elbo_cur, elbo_last, atol=1e-8)) + elbo_increase = elbo_cur >= elbo_last or jnp.isclose( + elbo_cur, elbo_last, atol=1e-8 ) - if (not elbo_increase) or jnp.isnan(elbo_cur): + if not elbo_increase: log.logger.warning( f"Optimization concludes after {o_iter + 1} iterations." + f" ELBO decreases. Final ELBO score: {elbo_cur}. Return last iteration's results."