Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

change the pdb_paths working style and support for loading both local… #214

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
174 changes: 98 additions & 76 deletions graphein/ml/datasets/torch_geometric_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ def __init__(
self,
root: str,
name: str,
pdb_paths: Optional[List[str]] = None,
pdb_codes: Optional[List[str]] = None,
uniprot_ids: Optional[List[str]] = None,
graph_label_map: Optional[Dict[str, torch.Tensor]] = None,
node_label_map: Optional[Dict[str, torch.Tensor]] = None,
chain_selection_map: Optional[Dict[str, List[str]]] = None,
pdb_paths: Optional[List[str]] = [],
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the reasoning for using empty lists as the default arg?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

empty lists can add together even if they are empty, while None can't. So we can skip some if for different statements of the user pass pdb_paths or pdb_codes or uniprot_ids, and just merge them into self.structures, which is used at process func and it works like os.listdir(self.raw_dir).

As for some potential bugs, i'm really not sure would this will cause some bugs as i use empty list instead of None.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be None

https://stackoverflow.com/questions/366422/what-is-the-pythonic-way-to-avoid-default-parameters-that-are-empty-lists

If you want to retain the behaviour inside the object, you could do:

if working_list is None: working_list = []

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, and actually i have done that in the latest commit

pdb_codes: Optional[List[str]] = [],
uniprot_ids: Optional[List[str]] = [],
graph_labels: Optional[List[torch.Tensor]] = None,
node_labels: Optional[List[torch.Tensor]] = None,
chain_selections: Optional[List[str]] = None,
graphein_config: ProteinGraphConfig = ProteinGraphConfig(),
graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor(
src_format="nx", dst_format="pyg"
Expand All @@ -73,13 +73,13 @@ def __init__(
:type root: str
:param name: Name of the dataset. Will be saved to ``data_$name.pt``.
:type name: str
:param pdb_paths: List of full path of pdb files to load. Defaults to ``None``.
:param pdb_paths: List of full path of pdb files to load. Defaults to ``List``.
:type pdb_paths: Optional[List[str]], optional
:param pdb_codes: List of PDB codes to download and parse from the PDB.
Defaults to None.
Defaults to List.
:type pdb_codes: Optional[List[str]], optional
:param uniprot_ids: List of Uniprot IDs to download and parse from
Alphafold Database. Defaults to ``None``.
Alphafold Database. Defaults to ``List``.
:type uniprot_ids: Optional[List[str]], optional
:param graph_label_map: Dictionary mapping PDB/Uniprot IDs to
graph-level labels. Defaults to ``None``.
Expand Down Expand Up @@ -130,54 +130,56 @@ def __init__(
self.pdb_codes = (
[pdb.lower() for pdb in pdb_codes]
if pdb_codes is not None
else None
else []
)
self.uniprot_ids = (
[up.upper() for up in uniprot_ids]
if uniprot_ids is not None
else None
else []
)

self.pdb_paths = pdb_paths
if self.pdb_paths is None:
if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
self.structures = pdb_codes
elif self.uniprot_ids:
self.structures = uniprot_ids
# Use local saved pdb_files instead of download or move them to self.root/raw dir
else:
if isinstance(self.pdb_paths, list):
self.structures = [
# make sure root path is unique
if self.pdb_paths:
# add pdb_paths' name into self.structure
self.pdb_paths_name = [
os.path.splitext(os.path.split(pdb_path)[-1])[0]
for pdb_path in self.pdb_paths
]
self.pdb_path, _ = os.path.split(self.pdb_paths[0])

if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
self.structures = pdb_codes
elif self.uniprot_ids:
self.structures = uniprot_ids
self.af_version = af_version
# if root pdb_path is not unique raise error since we will save all pdb into this root pdb_path and take it as the self.raw_dir
if len(set([os.path.split(pdb_path)[0] for pdb_path in pdb_paths])) != 1:
raise ValueError("pdb_paths should have only one root path not so much!")
else:
self.pdb_paths_name = []

self.structures = list(set(self.pdb_codes + self.uniprot_ids + self.pdb_paths_name)) # remove some pdb_codes is in pdb_path and loaded repeately
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this should be a set operation. With chain selections you may want to have e.g. 3eiy_A and 3eiy_B as different examples in your dataset.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, i guess it would'n make a difference at chain selection, this set operation is to drop duplicate in the result list of pdb_codes + uniprot_ids + paths_name. As you can see, local dir may have some pdb files like 10gs.pdb, and if pdb_codes also have 10gs to download, and self.structures would contain double 10gs and so the finial dataset object will have duplicate Data object.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It becomes a problem here though (L283), no?

    def process(self):
        """Process structures into PyG format and save to disk."""
        # Read data into huge `Data` list.
        structure_files = [
            f"{self.raw_dir}/{pdb}.pdb" for pdb in self.structures
        ]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, i guess not.
code like below in the tests/ml/test.ipynb

from graphein.ml.datasets import InMemoryProteinGraphDataset

local_dir = "../protein/test_data"
pdb_paths = [osp.join(local_dir, pdb_file) for pdb_file in os.listdir(local_dir) if pdb_file.endswith(".pdb")]

ds = InMemoryProteinGraphDataset(root = "../protein/test_data/InMemoryProteinGraphDataset",
                    name = "InMemoryProteinGraphDataset_test",
                    pdb_paths=pdb_paths,
                    pdb_codes=["10gs"],
                    uniprot_ids=["A0A6J1BG53", "A0A6P5Z5F7"],
                    af_version=3)

and before running it:
image

then run it :
image

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you see what happens with:

from graphein.ml.dataset import InMemoryProteinGraphDataset

ds = InMemoryProteinGraphDataset(root = ""../protein.test_data/InMemoryProteinGraphDataset", pdb_paths=pdb_paths, pdb_codes = ["4hhb", "4hhb"], chain_selection=["A","B"])

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, i'll try later

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you see what happens with:

from graphein.ml.dataset import InMemoryProteinGraphDataset

ds = InMemoryProteinGraphDataset(root = ""../protein.test_data/InMemoryProteinGraphDataset", pdb_paths=pdb_paths, pdb_codes = ["4hhb", "4hhb"], chain_selection=["A","B"])

image

and why i ['4hhs', '4hhs']

image

I guess this may need lots of change~


# Labels & Chains
if graph_labels is not None:
self.graph_label_map = dict(enumerate(graph_labels))
else:
self.graph_label_map = None

if node_labels is not None:
self.node_label_map = dict(enumerate(node_labels))
else:
self.node_label_map = None
if chain_selections is not None:
self.chain_selection_map = dict(enumerate(chain_selections))
else:
self.chain_selection_map = None
self.validate_input()
self.bad_pdbs: List[
str
] = [] # list of pdb codes that failed to download

# Labels & Chains
self.graph_label_map = graph_label_map
self.node_label_map = node_label_map
self.chain_selection_map = chain_selection_map

# Configs
self.config = graphein_config
self.graph_format_convertor = graph_format_convertor
self.graph_transformation_funcs = graph_transformation_funcs
self.pdb_transform = pdb_transform
self.num_cores = num_cores
self.af_version = af_version

super().__init__(
root,
transform=transform,
Expand All @@ -200,10 +202,34 @@ def processed_file_names(self) -> List[str]:
@property
def raw_dir(self) -> str:
if self.pdb_paths is not None:
return self.pdb_path # replace raw dir with user local pdb_path
# replace raw dir with user local pdb_path; so pdb_paths should be located in the same place
self.pdb_path, _ = os.path.split(self.pdb_paths[0])
return self.pdb_path
else:
return os.path.join(self.root, "raw")

def validate_input(self):
if self.graph_label_map is not None:
assert len(self.structures) == len(
self.graph_label_map
), "Number of proteins and graph labels must match"
if self.node_label_map is not None:
assert len(self.structures) == len(
self.node_label_map
), "Number of proteins and node labels must match"
if self.chain_selection_map is not None:
assert len(self.structures) == len(
self.chain_selection_map
), "Number of proteins and chain selections must match"
assert len(
{
f"{pdb}_{chain}"
for pdb, chain in zip(
self.structures, self.chain_selection_map
)
}
) == len(self.structures), "Duplicate protein/chain combinations"

def download(self):
"""Download the PDB files from RCSB or Alphafold."""
self.config.pdb_dir = Path(self.raw_dir)
Expand All @@ -225,6 +251,7 @@ def download(self):
for pdb in set(self.pdb_codes)
if not os.path.exists(Path(self.raw_dir) / f"{pdb}.pdb")
]
print("downloading uniprotids")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using log would be better.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ohhhhhhhhhhh, too sry for these print, forget to remove them, XD.

I'll remove them today

if self.uniprot_ids:
[
download_alphafold_structure(
Expand All @@ -237,6 +264,7 @@ def download(self):
]

def __len__(self) -> int:
"""Returns length of data set (number of structures)."""
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should return the number of examples (not just the number of structures for the multiple chain reason I mentioned previously)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok

return len(self.structures)

def transform_pdbs(self):
Expand Down Expand Up @@ -327,15 +355,12 @@ class ProteinGraphDataset(Dataset):
def __init__(
self,
root: str,
pdb_paths: Optional[List[str]] = None,
pdb_codes: Optional[List[str]] = None,
uniprot_ids: Optional[List[str]] = None,
# graph_label_map: Optional[Dict[str, int]] = None,
pdb_paths: Optional[List[str]] = [],
pdb_codes: Optional[List[str]] = [],
uniprot_ids: Optional[List[str]] = [],
graph_labels: Optional[List[torch.Tensor]] = None,
node_labels: Optional[List[torch.Tensor]] = None,
chain_selections: Optional[List[str]] = None,
# node_label_map: Optional[Dict[str, int]] = None,
# chain_selection_map: Optional[Dict[str, List[str]]] = None,
graphein_config: ProteinGraphConfig = ProteinGraphConfig(),
graph_format_convertor: GraphFormatConvertor = GraphFormatConvertor(
src_format="nx", dst_format="pyg"
Expand All @@ -356,22 +381,20 @@ def __init__(
:param root: Root directory where the dataset should be saved.
:type root: str
:param pdb_paths: List of full path of pdb files to load. Defaults to ``None``.
:param pdb_paths: List of full path of pdb files to load. Defaults to ``List``.
:type pdb_paths: Optional[List[str]], optional
:param pdb_codes: List of PDB codes to download and parse from the PDB.
Defaults to ``None``.
Defaults to ``List``.
:type pdb_codes: Optional[List[str]], optional
:param uniprot_ids: List of Uniprot IDs to download and parse from
Alphafold Database. Defaults to ``None``.
Alphafold Database. Defaults to ``List``.
:type uniprot_ids: Optional[List[str]], optional
:param graph_label_map: Dictionary mapping PDB/Uniprot IDs to
graph-level labels. Defaults to ``None``.
:type graph_label_map: Optional[Dict[str, Tensor]], optional
:param node_label_map: Dictionary mapping PDB/Uniprot IDs to node-level
labels. Defaults to ``None``.
:type node_label_map: Optional[Dict[str, torch.Tensor]], optional
:param chain_selection_map: Dictionary mapping, defaults to ``None``.
:type chain_selection_map: Optional[Dict[str, List[str]]], optional
:param graph_labels: List mapping to self.structures by index to graph-level labels. Defaults to ``None``.
:type graph_labels: Optional[List[torch.Tensor]], optional
:param node_labels: List mapping to self.structures by index to node-level labels. Defaults to ``None``.
:type node_labels: Optional[List[torch.Tensor]], optional
:param chain_selections: List mapping to self.structures by index to chain selection, defaults to ``None``.
:type chain_selections: Optional[List[str]], optional
:param graphein_config: Protein graph construction config, defaults to
``ProteinGraphConfig()``.
:type graphein_config: ProteinGraphConfig, optional
Expand Down Expand Up @@ -412,34 +435,32 @@ def __init__(
self.pdb_codes = (
[pdb.lower() for pdb in pdb_codes]
if pdb_codes is not None
else None
else []
)
self.uniprot_ids = (
[up.upper() for up in uniprot_ids]
if uniprot_ids is not None
else None
else []
)
self.pdb_paths = pdb_paths
if self.pdb_paths is None:
if self.pdb_codes and self.uniprot_ids:
self.structures = self.pdb_codes + self.uniprot_ids
elif self.pdb_codes:
self.structures = pdb_codes
elif self.uniprot_ids:
self.structures = uniprot_ids
# Use local saved pdb_files instead of download or move them to self.root/raw dir
else:
if isinstance(self.pdb_paths, list):
self.structures = [
# make sure root path is unique
if self.pdb_paths:
# add pdb_paths' name into self.structure
self.pdb_paths_name = [
os.path.splitext(os.path.split(pdb_path)[-1])[0]
for pdb_path in self.pdb_paths
]
self.pdb_path, _ = os.path.split(self.pdb_paths[0])

# Labels & Chains
# if root pdb_path is not unique raise error since we will save all pdb into this root pdb_path and take it as the self.raw_dir
if len(set([os.path.split(pdb_path)[0] for pdb_path in pdb_paths])) != 1:
raise ValueError("pdb_paths should have only one root path not so much!")
else:
self.pdb_paths_name = []

self.structures = list(set(self.pdb_codes + self.uniprot_ids + self.pdb_paths_name)) # remove some pdb_codes is in pdb_path and loaded repeately

self.examples: Dict[int, str] = dict(enumerate(self.structures))

# Labels & Chains
if graph_labels is not None:
self.graph_label_map = dict(enumerate(graph_labels))
else:
Expand All @@ -460,9 +481,9 @@ def __init__(
# Configs
self.config = graphein_config
self.graph_format_convertor = graph_format_convertor
self.num_cores = num_cores
self.pdb_transform = pdb_transform
self.graph_transformation_funcs = graph_transformation_funcs
self.pdb_transform = pdb_transform
self.num_cores = num_cores
self.af_version = af_version
super().__init__(
root,
Expand Down Expand Up @@ -492,8 +513,10 @@ def processed_file_names(self) -> List[str]:

@property
def raw_dir(self) -> str:
if self.pdb_paths is not None:
return self.pdb_path # replace raw dir with user local pdb_path
if self.pdb_paths:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually think it would be useful to allow users to choose a path for raw_dir when initialising the Dataset objects.

Copy link
Contributor Author

@1511878618 1511878618 Sep 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I agree.
If we simply change self.raw_dir instead of self.pdb_paths, where the former is a folder dir the latter is a list containing pdb_file dir, i guess we will use os.listdir to get local pdb files dir.
And the question is if os.listdir in the func, and then the order of self.structure maybe hard to match the order of graph_labels and node_labels, since we match the labels by index of list, i guess.
image

image

image

I'm not sure about this, i prefer to dict, which key is the names like {'10gs':0} would be better than {0:0}. And then we could just change the raw_dir and os.listdir and get a list of pdb file dir containing both local and downloaded pdb files, and process and assign each pdb files with their node_graph_label or chain_selection or graph_label by their name (remove root path and suffix like ./test/10gs.pdb -> 10gs) not by the enumunated index (which i think it is hard to match the correct order with pdb files when passing the graph_labels)

This description is not very clear, i'll try to make it clear later...

If something wrong in my understanding, please tell me 😄 , i'm still reading and learning your code lol. It's really a pythonic code, i learnt a lot 👍 👍

Copy link
Owner

@a-r-j a-r-j Sep 20, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we simply change self.raw_dir instead of self.pdb_paths, where the former is a folder dir the latter is a list containing pdb_file dir, i guess we will use `os.listdir`` to get local pdb files dir.

I don't think this is the best idea. I think being explicit about the paths users want to use is best. For instance, people may want to use only a subset of their dataset (rather than everything in the directory - e.g. imagine where you want to keep all your pdb files together but train/test on different subsets). It also has the potential problem with hidden files like .DS_Store etc. You're also completely right about the matching the list to node labels etc.

I'm not sure about this, i prefer to dict, which key is the names like {'10gs':0} would be better than {0:0}

This was my initial implementation. However, this ran into the problem where you may have different examples in your dataset drawn from different chains of the same PDB. E.g. imagine you have 3eiy_A and 3eiy_B with different labels. The current implementation allows for this, whereas indexing on the PDB name does not.

If something wrong in my understanding, please tell me 😄 , i'm still reading and learning your code lol. It's really a pythonic code, i learnt a lot 👍 👍

Thanks!! Me too!

# replace raw dir with user local pdb_path; so pdb_paths should be located in the same place
self.pdb_path, _ = os.path.split(self.pdb_paths[0])
return self.pdb_path
else:
return os.path.join(self.root, "raw")

Expand Down Expand Up @@ -610,7 +633,6 @@ def divide_chunks(l: List[str], n: int = 2) -> Generator:
)
if self.graph_transformation_funcs is not None:
graphs = [self.transform_graphein_graphs(g) for g in graphs]

# Convert to PyTorch Geometric Data
graphs = [self.graph_format_convertor(g) for g in graphs]

Expand Down
3 changes: 2 additions & 1 deletion graphein/protein/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,9 @@ def download_alphafold_structure(
query_url = f"{BASE_URL}AF-{uniprot_id}-F1-model_v{version}.cif"
if pdb:
query_url = f"{BASE_URL}AF-{uniprot_id}-F1-model_v{version}.pdb"
print(query_url)
a-r-j marked this conversation as resolved.
Show resolved Hide resolved
structure_filename = wget.download(query_url, out=out_dir)

if rename:
extension = ".pdb" if pdb else ".cif"
os.rename(
Expand Down
Loading