
Commit

Merge branch 'main' into document_xattention
TJ-Solergibert authored Sep 16, 2024
2 parents a750b45 + 7b7ead9 commit ed51183
Showing 16 changed files with 369 additions and 106 deletions.
2 changes: 1 addition & 1 deletion docs/nanoset.md
@@ -79,7 +79,7 @@ To work with `Nanosets`, we just need to configure 1 argument:

Finally, to use the `Nanosets`, launch the training with [`run_train.py`](../run_train.py).
```shell
torchrun --nproc-per-node 8 run_train.py --config configs/config_nanoset.yaml
torchrun --nproc-per-node 1 run_train.py --config examples/config_nanoset.yaml
```

## Under the hood
4 changes: 4 additions & 0 deletions examples/doremi/README.md
@@ -87,3 +87,7 @@ For evaluation, we do uniform sampling on the test set to evaluate a 2.5B model
- 2.5B llama trained using the optimized weights: https://huggingface.co/nanotron/doremi-llama-2.5b-optimized-weights

and the dataset: https://huggingface.co/datasets/nanotron/the-pile-for-doremi

#### Thoughts

DoReMi is useful if you don't initially have a good idea of what the right distribution for your training data would be, or if you want a quick way to find a better baseline than the uniform distribution before hand-tuning the data weights. In my previous experiments, DoReMi matched the pretraining performance of the distribution used for Mamba training but couldn't outperform it. I suspect it doesn't work well when the differences are subtle, i.e. when the gap between your known best distribution and a genuinely better one isn't significant.
12 changes: 12 additions & 0 deletions examples/mamba/README.md
@@ -18,6 +18,18 @@ pip install -r requirements.txt

> https://wandb.ai/bouteille/test/reports/Mamba-loss--Vmlldzo2OTgwNDM5
## Bug related to nanotron
Encountered the following issue when running `train_mamba.sh`:
```
causal_conv1d_cuda.cpython-310-x86_64-linux-gnu.so: undefined symbol: _ZNK3c1017SymbolicShapeMeta18init_is_contiguousEv
```
Solved this by reinstalling the packages:
```shell
pip uninstall mamba-ssm
pip install causal_conv1d==1.1.1
pip install mamba-ssm --no-cache-dir
```
See https://github.com/state-spaces/mamba/issues/169 for details.


## Credits
Credits to the following repositories from which the code was adapted:
- https://github.com/state-spaces/mamba
5 changes: 5 additions & 0 deletions examples/mup/README.md
@@ -32,3 +32,8 @@ We trained a 350m model with spectral µTransfer and standard parametrization us
Please check the directory [[./examples/mup/configs]](/examples/mup/configs) for the configurations we used to reproduce the experiments.

![LLaMA](./assets/llama.png)


#### Thoughts

For spectral µTransfer, we only ran experiments on an MLP-only model [link] and a 300M LLaMA [link] (links to the experiment configs are in the µP README). However, when we tested it on 1B/8B models, IIRC the loss blew up for some reason. So we'd recommend trying µTransfer rather than spectral µTransfer.
3 changes: 2 additions & 1 deletion src/nanotron/config/config.py
@@ -128,7 +128,7 @@ def __post_init__(self):
class DataArgs:
"""Arguments related to the data and data files processing"""

dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, ChatDatasetsArgs]
dataset: Optional[Union[PretrainDatasetsArgs, NanosetDatasetsArgs, ChatDatasetsArgs]]
seed: Optional[int]
num_loading_workers: Optional[int] = 1

@@ -162,6 +162,7 @@ class CheckpointsArgs:
checkpoints_path: Path
checkpoint_interval: int
save_initial_state: Optional[bool] = False
save_final_state: Optional[bool] = False
resume_checkpoint_path: Optional[Path] = None
checkpoints_path_is_shared_file_system: Optional[bool] = False

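The new `save_final_state` flag is just another field on the `CheckpointsArgs` dataclass. As a rough sketch only (the usual entry point is the YAML config; the import path is assumed from the repo layout, and the class may apply validation not visible in this hunk), it would be set like this:

```python
from pathlib import Path

from nanotron.config import CheckpointsArgs  # import path assumed

checkpoints = CheckpointsArgs(
    checkpoints_path=Path("checkpoints"),
    checkpoint_interval=1000,   # save every 1000 steps
    save_initial_state=False,   # existing flag: also save the state before training starts
    save_final_state=True,      # new flag added here: also save a checkpoint when training finishes
)
```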
4 changes: 4 additions & 0 deletions src/nanotron/config/parallelism_config.py
@@ -23,6 +23,7 @@ class ParallelismArgs:
pp_engine: Pipeline engine to use between "1f1b" and "afab"
tp_mode: TP mode to use between "all_reduce" and "reduce_scatter": all_reduce is normal, reduce_scatter activate sequence parallelism
tp_linear_async_communication: Whether to use async communication in TP linear layers
recompute_layer: Whether to recompute each Transformer layer to save memory.
"""

dp: int
@@ -31,6 +32,9 @@
pp_engine: Optional[PipelineEngine] = None
tp_mode: Optional[TensorParallelLinearMode] = None
tp_linear_async_communication: Optional[bool] = None
recompute_layer: bool = False

tp_recompute_allgather: bool = True

expert_parallel_size: int = 1

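The docstring gains an entry for `recompute_layer`, and the class gains both `recompute_layer` and `tp_recompute_allgather`. A minimal sketch of how these flags might be set when building the parallelism config in Python; the `dp`/`pp`/`tp` values are placeholders, and `pp`/`tp` are assumed required fields hidden in the collapsed part of the hunk:

```python
from nanotron.config import ParallelismArgs  # import path assumed

parallel_config = ParallelismArgs(
    dp=2,  # data-parallel size
    pp=1,  # pipeline-parallel size (assumed field, collapsed in the hunk above)
    tp=4,  # tensor-parallel size (assumed field, collapsed in the hunk above)
    recompute_layer=True,         # new: recompute each Transformer layer in backward to save memory
    tp_recompute_allgather=True,  # new (default True): recompute the all-gather in TP column-linear layers
)
```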
40 changes: 32 additions & 8 deletions src/nanotron/models/llama.py
@@ -14,10 +14,11 @@
# limitations under the License.
"""PyTorch LLaMa model."""

from typing import Dict, Optional, Union
from typing import Dict, Optional, Union, List

import torch
from torch import nn
from torch.utils.checkpoint import CheckpointFunction

from nanotron import distributed as dist
from nanotron import logging
@@ -154,6 +155,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=gate_up_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
self.down_proj = TensorParallelRowLinear(
config.intermediate_size,
@@ -163,8 +165,7 @@
bias=False,
async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER,
)
# TODO @nouamane: why can't we torch.jit.script GLUActivation?
self.split_silu_mul = GLUActivation(config.hidden_act)
self.split_silu_mul = torch.compile(GLUActivation(config.hidden_act))

def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim]
merged_states = self.gate_up_proj(hidden_states)
@@ -301,6 +302,7 @@ def __init__(
bias=False,
async_communication=tp_linear_async_communication,
contiguous_chunks=qkv_contiguous_chunks,
tp_recompute_allgather=parallel_config.tp_recompute_allgather,
)
# TODO(kunhao): We want to have only one version per device and not one version per layer.
self.rotary_embedding = RotaryEmbedding(
@@ -591,12 +593,14 @@ def __init__(

self.post_attention_layernorm = TritonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg)

def forward(

self.recompute_layer = parallel_config.recompute_layer

def _core_forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
) -> List[Union[torch.Tensor, TensorPointer]]:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

@@ -609,12 +613,31 @@ def forward(
hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"]
hidden_states = hidden_states + residual

return hidden_states, output["sequence_mask"]

def _checkpointed_forward(
self,
hidden_states: torch.Tensor,
sequence_mask: torch.Tensor,
) -> List[torch.Tensor]:
return CheckpointFunction.apply(self._core_forward, True, hidden_states, sequence_mask)

def forward(
self,
hidden_states: Union[torch.Tensor, TensorPointer],
sequence_mask: Union[torch.Tensor, TensorPointer],
) -> Dict[str, Union[torch.Tensor, TensorPointer]]:

if self.recompute_layer and not isinstance(hidden_states, TensorPointer):
hidden_states, sequence_mask = self._checkpointed_forward(hidden_states, sequence_mask)
else:
hidden_states, sequence_mask = self._core_forward(hidden_states, sequence_mask)

return {
"hidden_states": hidden_states,
"sequence_mask": output["sequence_mask"],
"sequence_mask": sequence_mask,
}


class Embedding(nn.Module, AttachableStore):
def __init__(self, tp_pg: dist.ProcessGroup, config: LlamaConfig, parallel_config: Optional[ParallelismArgs]):
super().__init__()
@@ -716,6 +739,7 @@ def __init__(
# TODO @thomasw21: refactor so that we store that default in a single place.
"mode": self.tp_mode,
"async_communication": tp_linear_async_communication,
"tp_recompute_allgather": parallel_config.tp_recompute_allgather,
},
module_input_keys={"x"},
module_output_keys={"logits"},
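The `recompute_layer` path above splits the decoder block into `_core_forward` and a `_checkpointed_forward` that wraps it in `CheckpointFunction.apply(..., True, ...)`, i.e. reentrant activation checkpointing: the layer's activations are discarded in forward and recomputed during backward. A standalone sketch of the same pattern using the public `torch.utils.checkpoint` API (a toy block, not nanotron's decoder layer):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class ToyBlockWithRecompute(nn.Module):
    """Toy residual block illustrating the recompute_layer pattern."""

    def __init__(self, hidden_size: int, recompute_layer: bool = False):
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.recompute_layer = recompute_layer

    def _core_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The actual computation; its intermediate activations are what the
        # checkpointed path avoids keeping in memory.
        return hidden_states + self.mlp(self.norm(hidden_states))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.recompute_layer:
            # Reentrant checkpointing (use_reentrant=True corresponds to
            # CheckpointFunction.apply(fn, True, ...)): forward runs without
            # saving activations, then _core_forward is re-run during backward.
            return checkpoint(self._core_forward, hidden_states, use_reentrant=True)
        return self._core_forward(hidden_states)


# Usage: memory drops by roughly the block's activation footprint, at the cost
# of one extra forward pass per layer during backward.
x = torch.randn(8, 64, requires_grad=True)
block = ToyBlockWithRecompute(hidden_size=64, recompute_layer=True)
block(x).sum().backward()
```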
@@ -85,7 +85,8 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
@staticmethod
def backward(ctx, grad_output):
group = ctx.group
return DifferentiableReduceScatterSum.apply(grad_output, group), None
out = DifferentiableReduceScatterSum.apply(grad_output, group)
return out, None


class DifferentiableReduceScatterSum(torch.autograd.Function):
@@ -113,7 +114,7 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
*rest_size,
device=tensor.device,
dtype=tensor.dtype,
requires_grad=tensor.requires_grad,
requires_grad=False,
)
dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
return sharded_tensor
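These hunks touch the differentiable all-gather/reduce-scatter pair: the backward of an all-gather is a sum reduce-scatter of the incoming gradient (now split across two lines), and the reduce-scatter output buffer is allocated with `requires_grad=False`, since it is plain storage for the collective and autograd tracking is handled by the `Function` machinery. A hedged sketch of the same adjoint relationship, assuming an initialized process group whose backend supports `reduce_scatter_tensor` (e.g. NCCL); this is an illustration, not nanotron's actual implementation:

```python
from typing import Optional

import torch
import torch.distributed as dist


class AllGatherWithReduceScatterGrad(torch.autograd.Function):
    """All-gather along dim 0 whose backward is a sum reduce-scatter."""

    @staticmethod
    def forward(ctx, tensor: torch.Tensor, group: Optional[dist.ProcessGroup]):
        ctx.group = group
        world_size = dist.get_world_size(group)
        gathered = torch.empty(
            world_size * tensor.shape[0], *tensor.shape[1:],
            device=tensor.device, dtype=tensor.dtype,
        )
        dist.all_gather_into_tensor(gathered, tensor.contiguous(), group=group)
        return gathered

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        group = ctx.group
        world_size = dist.get_world_size(group)
        # Plain buffer for the collective output: no grad tracking needed here,
        # in the same spirit as the requires_grad=False change above.
        grad_input = torch.empty(
            grad_output.shape[0] // world_size, *grad_output.shape[1:],
            device=grad_output.device, dtype=grad_output.dtype,
            requires_grad=False,
        )
        dist.reduce_scatter_tensor(
            grad_input, grad_output.contiguous(), op=dist.ReduceOp.SUM, group=group
        )
        # One gradient per forward input: the tensor gets grad_input, the group gets None.
        return grad_input, None
```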
