Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mini OCP #1

Open
wants to merge 3 commits into
base: mini
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions cdvae/common/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,27 @@
import copy
import itertools

from ase import Atoms
from pymatgen.core.structure import Structure
from pymatgen.core.lattice import Lattice
from pymatgen.analysis.graphs import StructureGraph
from pymatgen.analysis import local_env

from networkx.algorithms.components import is_connected
from pymatgen.io.ase import AseAtomsAdaptor

from sklearn.metrics import accuracy_score, recall_score, precision_score

from torch_scatter import scatter

from p_tqdm import p_umap
import pyarrow as pa


# Tensor of unit cells. Assumes 27 cells with -1, 0, 1 offsets in the x, y and z dimensions.
# Note that, unlike OCP, we have 27 offsets here because we are in 3D.
from common.parse import atoms_json_loads

OFFSET_LIST = [
[-1, -1, -1],
[-1, -1, 0],
Expand Down Expand Up @@ -87,7 +92,11 @@
def build_crystal(crystal_str, niggli=True, primitive=False):
    """Parse a CIF string and pass the structure to the shared builder.

    The text is read with ``Structure.from_str(..., fmt='cif')``. The
    ``niggli`` and ``primitive`` flags are forwarded unchanged to
    ``get_crystal_from_structure``, whose result is returned.
    """
    parsed_structure = Structure.from_str(crystal_str, fmt='cif')
    return get_crystal_from_structure(
        parsed_structure, niggli=niggli, primitive=primitive)


def get_crystal_from_structure(crystal, niggli=True, primitive=False):
"""Build crystal from cif string."""
if primitive:
crystal = crystal.get_primitive_structure()

Expand Down Expand Up @@ -681,6 +690,44 @@ def process_one(row, niggli, primitive, graph_method, prop_list):
return ordered_results


def preprocess_arrow(input_file, num_workers, niggli, primitive, graph_method,
                     prop_list):
    """Build crystal graphs for every row of an Arrow IPC file.

    Args:
        input_file: Path of the Arrow file. Each row must provide a ``key``
            column and an ``atoms_json`` column.
        num_workers: Passed to ``p_umap`` as ``num_cpus``.
        niggli: Forwarded to ``get_crystal_from_structure``.
        primitive: Forwarded to ``get_crystal_from_structure``.
        graph_method: Forwarded to ``build_crystal_graph``.
        prop_list: Column names to copy into each result when the row has
            them.

    Returns:
        A list with one dict per row, in the same order as the rows of the
        file. Each dict has the keys ``mp_id``, ``cif``, ``graph_arrays`` and
        ``energy_per_atom`` plus any copied properties.
    """
    # Use the file as a context manager so the OS handle is released as soon
    # as the table has been materialised instead of leaking until GC.
    with pa.OSFile(input_file, 'rb') as source:
        df = pa.RecordBatchFileReader(source).read_all().to_pandas()

    def process_one(row, niggli, primitive, graph_method, prop_list):
        """Convert a single row into its result dict."""
        crystal_str = row['atoms_json']
        atoms = Atoms.fromdict(atoms_json_loads(crystal_str))
        structure = AseAtomsAdaptor.get_structure(atoms)
        crystal = get_crystal_from_structure(
            structure, niggli=niggli, primitive=primitive)
        graph_arrays = build_crystal_graph(crystal, graph_method)
        properties = {k: row[k] for k in prop_list if k in row.keys()}
        result_dict = {
            'mp_id': row['key'],
            # NOTE: despite the key name this is the atoms JSON string, not
            # CIF text; the key is kept as-is.
            'cif': crystal_str,
            'graph_arrays': graph_arrays,
            'energy_per_atom': atoms.info['energy'] / len(atoms)
        }
        # Applied last, so a property column with the same name as one of
        # the keys above overrides it.
        result_dict.update(properties)
        return result_dict

    num_rows = len(df)
    unordered_results = p_umap(
        process_one,
        [df.iloc[idx] for idx in range(num_rows)],
        [niggli] * num_rows,
        [primitive] * num_rows,
        [graph_method] * num_rows,
        [prop_list] * num_rows,
        num_cpus=num_workers)

    # p_umap yields results in completion order; restore the file's row
    # order by looking each row's key up again.
    mpid_to_results = {result['mp_id']: result for result in unordered_results}
    ordered_results = [mpid_to_results[key] for key in df['key']]

    return ordered_results


def preprocess_tensors(crystal_array_list, niggli, primitive, graph_method):
def process_one(batch_idx, crystal_array, niggli, primitive, graph_method):
frac_coords = crystal_array['frac_coords']
Expand Down
36 changes: 36 additions & 0 deletions cdvae/common/parse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Any, Dict, TextIO

import numpy as np
import json

from ase import Atoms


def atom_json_default(obj: Any) -> Any:
    """``default=`` hook that lets ``json`` encode atoms-related objects.

    Args:
        obj: A value the standard JSON encoder could not serialise.

    Returns:
        ``obj.todict()`` when the object offers a ``todict`` method; a
        ``{"_np": True, "_d": <nested list>}`` marker dict for a NumPy array
        (the form ``atom_object_hook`` turns back into an array); or the
        plain Python scalar for a NumPy scalar.

    Raises:
        TypeError: If ``obj`` is none of the supported kinds.
    """
    if hasattr(obj, "todict"):
        return obj.todict()
    if isinstance(obj, np.ndarray):
        return {"_np": True, "_d": obj.tolist()}
    if isinstance(obj, np.generic):
        # NumPy scalars (e.g. np.int64) are not JSON-native; unwrap them.
        return obj.item()
    # Mirror the wording of json's own error so failures stay recognisable.
    raise TypeError(
        f"Object of type {type(obj).__name__} is not JSON serializable")


def atom_object_hook(dct: Dict) -> Any:
    """``object_hook`` for ``json.load(s)`` that revives NumPy arrays.

    A decoded object containing the ``"_np"`` key is replaced by
    ``np.array`` of its ``"_d"`` entry; any other object is returned as is.
    """
    is_array_marker = "_np" in dct
    if not is_array_marker:
        return dct
    return np.array(dct["_d"])


def atoms_json_dump(atoms: Atoms, file_obj: TextIO) -> None:
    """Write ``atoms`` as JSON to ``file_obj``.

    Values the encoder cannot handle natively are converted through
    ``atom_json_default``.
    """
    return json.dump(obj=atoms, fp=file_obj, default=atom_json_default)


def atoms_json_dumps(atoms: Atoms) -> str:
    """Return ``atoms`` encoded as a JSON string.

    Values the encoder cannot handle natively are converted through
    ``atom_json_default``.
    """
    encoded = json.dumps(obj=atoms, default=atom_json_default)
    return encoded


def atoms_json_load(file_obj: TextIO) -> Atoms:
    """Read JSON from ``file_obj`` and rebuild the ``Atoms`` object.

    The document is decoded with ``atom_object_hook`` and the resulting
    payload is passed to ``Atoms.fromdict``.
    """
    payload = json.load(file_obj, object_hook=atom_object_hook)
    return Atoms.fromdict(payload)


def atoms_json_loads(s: str) -> Any:
    """Decode a JSON string using ``atom_object_hook``.

    Unlike ``atoms_json_load``, this function does NOT call
    ``Atoms.fromdict``: it returns the decoded payload itself, and the
    caller is responsible for building the ``Atoms`` object from it. The
    return annotation reflects that (it previously claimed ``Atoms``).

    Args:
        s: JSON text, e.g. as produced by ``atoms_json_dumps``.

    Returns:
        The decoded JSON value, with array marker objects already replaced
        by NumPy arrays.
    """
    return json.loads(s, object_hook=atom_object_hook)
24 changes: 21 additions & 3 deletions cdvae/pl_data/dataset.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from typing import Callable

import hydra
import omegaconf
import torch
import pandas as pd
from omegaconf import ValueNode
from torch.utils.data import Dataset
import pyarrow as pa

from torch_geometric.data import Data

from cdvae.common.utils import PROJECT_ROOT
from cdvae.common.data_utils import (
preprocess, preprocess_tensors, add_scaled_lattice_prop)
preprocess, preprocess_tensors, add_scaled_lattice_prop, preprocess_arrow)


class CrystDataset(Dataset):
Expand All @@ -21,14 +24,14 @@ def __init__(self, name: ValueNode, path: ValueNode,
super().__init__()
self.path = path
self.name = name
self.df = pd.read_csv(path)
self.df = self.load_path(path)
self.prop = prop
self.niggli = niggli
self.primitive = primitive
self.graph_method = graph_method
self.lattice_scale_method = lattice_scale_method

self.cached_data = preprocess(
self.cached_data = self.preprocess(
self.path,
preprocess_workers,
niggli=self.niggli,
Expand Down Expand Up @@ -72,6 +75,21 @@ def __getitem__(self, index):
def __repr__(self) -> str:
return f"CrystDataset({self.name=}, {self.path=})"

def load_path(self, path):
    """Load the CSV at ``path`` into a pandas DataFrame and return it."""
    frame = pd.read_csv(path)
    return frame

def preprocess(self, *args, **kwargs):
    """Forward all arguments to the module-level ``preprocess`` function.

    Exposed as a method so a subclass can substitute another routine.
    """
    cached = preprocess(*args, **kwargs)
    return cached


class ArrowCrystDataset(CrystDataset):
    """``CrystDataset`` variant that reads its input from an Arrow IPC file.

    Overrides the two hooks of the base class: ``load_path`` reads the file
    with pyarrow and ``preprocess`` delegates to ``preprocess_arrow``.
    """

    def load_path(self, path):
        """Read the Arrow file at ``path`` and return it as a pyarrow table."""
        # Read the configured ``path`` rather than a hard-coded,
        # developer-local file, and close the handle once the table is
        # loaded.
        with pa.OSFile(path, 'rb') as source:
            return pa.RecordBatchFileReader(source).read_all()

    def preprocess(self, *args, **kwargs):
        """Forward all arguments to ``preprocess_arrow``."""
        return preprocess_arrow(*args, **kwargs)


class TensorCrystDataset(Dataset):
def __init__(self, crystal_array_list, niggli, primitive,
Expand Down
67 changes: 67 additions & 0 deletions conf/data/ocp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
root_path: ${oc.env:PROJECT_ROOT}/data/ocp
prop: energy_per_atom
num_targets: 1
# prop: scaled_lattice
# num_targets: 6
niggli: true
primitive: false
graph_method: crystalnn
lattice_scale_method: scale_length
preprocess_workers: 3
readout: mean
max_atoms: 240
otf_graph: false
eval_model_name: carbon


train_max_epochs: 80
early_stopping_patience: 100000
teacher_forcing_max_epoch: 1000


datamodule:
_target_: cdvae.pl_data.datamodule.CrystDataModule

datasets:
train:
_target_: cdvae.pl_data.dataset.ArrowCrystDataset
name: Formation energy train
path: ${data.root_path}/train.arrow
prop: ${data.prop}
niggli: ${data.niggli}
primitive: ${data.primitive}
graph_method: ${data.graph_method}
lattice_scale_method: ${data.lattice_scale_method}
preprocess_workers: ${data.preprocess_workers}

val:
- _target_: cdvae.pl_data.dataset.CrystDataset
name: Formation energy val
path: ${data.root_path}/val.csv
prop: ${data.prop}
niggli: ${data.niggli}
primitive: ${data.primitive}
graph_method: ${data.graph_method}
lattice_scale_method: ${data.lattice_scale_method}
preprocess_workers: ${data.preprocess_workers}

test:
- _target_: cdvae.pl_data.dataset.CrystDataset
name: Formation energy test
path: ${data.root_path}/test.csv
prop: ${data.prop}
niggli: ${data.niggli}
primitive: ${data.primitive}
graph_method: ${data.graph_method}
lattice_scale_method: ${data.lattice_scale_method}
preprocess_workers: ${data.preprocess_workers}

num_workers:
train: 2
val: 2
test: 2

batch_size:
train: 32
val: 32
test: 32
29 changes: 29 additions & 0 deletions conf/ocp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
expname: test

# metadata specialised for each experiment
core:
version: 0.0.1
tags:
- ${now:%Y-%m-%d}

hydra:
run:
dir: ${oc.env:HYDRA_JOBS}/singlerun/${now:%Y-%m-%d}/${expname}/

sweep:
dir: ${oc.env:HYDRA_JOBS}/multirun/${now:%Y-%m-%d}/${expname}/
subdir: ${hydra.job.num}_${hydra.job.id}

job:
env_set:
WANDB_START_METHOD: thread
WANDB_DIR: ${oc.env:WABDB_DIR}

defaults:
- data: ocp
- model: vae_mini
- logging: offline
- optim: default
- train: cpu
# Uncomment this parameter to run jobs in parallel
# - override hydra/launcher: joblib
36 changes: 36 additions & 0 deletions data/ocp/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Carbon-24

Carbon-24 contains 10k carbon materials, which share the same composition, but have different structures. There is 1 element and the materials have 6 - 24 atoms in the unit cells.

## What is in the dataset?

Carbon-24 includes various carbon structures obtained via *ab initio* random structure searching (AIRSS) (Pickard & Needs, 2006; 2011) performed at 10 GPa.

## Stability of curated materials

The original dataset includes 101529 carbon structures, and we selected the 10% of the carbon structures with the lowest energy per atom to create Carbon-24. All 10153 structures are at a local energy minimum after DFT relaxation. The most stable structure is diamond at 10 GPa. All remaining structures are thermodynamically unstable but may be kinetically stable.

## Visualization of structures

<p align="center">
<img src="../../assets/carbon_24.png" />
</p>

## Citation

Please consider citing the following paper:

```
@misc{carbon2020data,
doi = {10.24435/MATERIALSCLOUD:2020.0026/V1},
url = {https://archive.materialscloud.org/record/2020.0026/v1},
author = {Pickard, Chris J.},
keywords = {DFT, ab initio random structure searching, carbon},
language = {en},
title = {AIRSS data for carbon at 10GPa and the C+N+H+O system at 1GPa},
publisher = {Materials Cloud},
year = {2020},
copyright = {info:eu-repo/semantics/openAccess}
}
```

Loading