add mistral 7b
moritztng committed Mar 7, 2024
1 parent 6b3fa7b commit a118532
Showing 6 changed files with 167 additions and 107 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1 +1,2 @@
target
env
7 changes: 4 additions & 3 deletions README.md
@@ -1,13 +1,14 @@
Like grep but for natural language questions. Based on Mixtral 8x7B. ~15 tokens/s on Nvidia RTX 3070 with 8GB memory.
Like grep but for natural language questions. Based on Mistral 7B or 8x7B. ~15 tokens/s on Nvidia RTX 3070 with 8GB memory.

# Installation
## Linux x86_64
If nvidia driver that supports cuda 12.1 exists, it installs cuda version, else cpu version. It's ~48GB.
If nvidia driver that supports cuda 12.1 exists, it installs cuda version, else cpu version. Replace `small` with `large` to install Mixtral 8x7B. It's ~7GB or ~48GB.
```bash
curl https://raw.githubusercontent.com/moritztng/fltr/main/install.sh -o install.sh && bash install.sh && source ~/.bashrc
curl https://raw.githubusercontent.com/moritztng/fltr/main/install.sh -o install.sh && bash install.sh small && source ~/.bashrc
```

# Quickstart
Add `--large` for Mixtral 8x7B.
```bash
fltr --file emails.txt --prompt "Is the following email spam? Email:" --batch-size 32
```
70 changes: 47 additions & 23 deletions convert.py
@@ -2,73 +2,97 @@
import numpy as np
from argparse import ArgumentParser


def serialize_fp32(file, tensor):
""" writes one fp32 tensor to file that is open in wb mode """
"""writes one fp32 tensor to file that is open in wb mode"""
d = tensor.detach().cpu().view(-1).to(torch.float32).numpy()
b = struct.pack(f'{len(d)}f', *d)
b = struct.pack(f"{len(d)}f", *d)
file.write(b)


def serialize_int8(file, tensor):
""" writes one int8 tensor to file that is open in wb mode """
"""writes one int8 tensor to file that is open in wb mode"""
d = tensor.detach().cpu().view(-1).numpy().astype(np.int8)
b = struct.pack(f'{len(d)}b', *d)
b = struct.pack(f"{len(d)}b", *d)
file.write(b)


def quantize_serialize(f, w, group_size):
"""
takes a tensor and returns the Q8_0 quantized version
i.e. symmetric quantization into int8, range [-127,127]
"""
assert w.numel() % group_size == 0
ori_shape = w.shape
w = w.float() # convert to float32
w = w.float() # convert to float32
w = w.reshape(-1, group_size)
# find the max in each group
wmax = torch.abs(w).max(dim=1).values
# calculate the scaling factor such that float = quant * scale
scale = wmax / 127.0
# scale into range [-127, 127]
quant = w / scale[:,None]
quant = w / scale[:, None]
# round to nearest integer
int8val = torch.round(quant).to(torch.int8)
# dequantize by rescaling
fp32val = (int8val.float() * scale[:,None]).view(-1)
fp32val = (int8val.float() * scale[:, None]).view(-1)
fp32valr = fp32val.reshape(-1, group_size)
# calculate the max error in each group
err = torch.abs(fp32valr - w).max(dim=1).values
# find the max error across all groups
maxerr = err.max().item()

serialize_int8(f, int8val)
serialize_fp32(f, scale)
return maxerr

parser = ArgumentParser(prog="llama 2 converter")

parser = ArgumentParser(prog="mistral converter")
parser.add_argument("output_path", type=str)
parser.add_argument("--checkpoint", type=str)
parser.add_argument("--group-size", default=64, type=int)
parser.add_argument("--moe", action="store_true")
parser.add_argument("--cuda", action="store_true")
args = parser.parse_args()

state_dict = torch.load(args.checkpoint, map_location="cpu", mmap=True)
state_dict = torch.load(args.checkpoint, map_location="cuda" if args.cuda else "cpu", mmap=True)
with open(args.output_path, "wb") as f:
serialize_fp32(f, state_dict['norm.weight'])
serialize_fp32(f, state_dict["norm.weight"])
print("norm.weight")
err = quantize_serialize(f, state_dict['tok_embeddings.weight'], args.group_size)
err = quantize_serialize(f, state_dict["tok_embeddings.weight"], args.group_size)
print(f"tok_embeddings.weight, error: {err}")
err = quantize_serialize(f, state_dict['output.weight'], args.group_size)
err = quantize_serialize(f, state_dict["output.weight"], args.group_size)
print(f"output.weight, error: {err}")
for i in range(32):
layer_prefix = f'layers.{i}.'
layer_prefix = f"layers.{i}."
print(layer_prefix)
for name in ['attention_norm.weight', 'ffn_norm.weight']:
for name in ["attention_norm.weight", "ffn_norm.weight"]:
serialize_fp32(f, state_dict[layer_prefix + name])
print(name)
for name in ['attention.wq.weight', 'attention.wk.weight', 'attention.wv.weight', 'attention.wo.weight', 'feed_forward.gate.weight']:
err = quantize_serialize(f, state_dict[layer_prefix + name], args.group_size)
for name in [
"attention.wq.weight",
"attention.wk.weight",
"attention.wv.weight",
"attention.wo.weight",
] + (
["feed_forward.gate.weight"]
if args.moe
else [
"feed_forward.w1.weight",
"feed_forward.w2.weight",
"feed_forward.w3.weight",
]
):
err = quantize_serialize(
f, state_dict[layer_prefix + name], args.group_size
)
print(f"{name}, error: {err}")
for e in range(8):
expert_prefix = layer_prefix + f"feed_forward.experts.{e}."
print(expert_prefix)
for name in ['w1.weight', 'w2.weight', 'w3.weight']:
err = quantize_serialize(f, state_dict[expert_prefix + name], args.group_size)
print(f"{name}, error: {err}")
if args.moe:
for e in range(8):
expert_prefix = layer_prefix + f"feed_forward.experts.{e}."
print(expert_prefix)
for name in ["w1.weight", "w2.weight", "w3.weight"]:
err = quantize_serialize(
f, state_dict[expert_prefix + name], args.group_size
)
print(f"{name}, error: {err}")
10 changes: 9 additions & 1 deletion install.sh
@@ -10,7 +10,15 @@ fi
INSTALL_DIR=~/Fltr
mkdir -p "$INSTALL_DIR"
curl -sSL https://github.com/moritztng/fltr/releases/download/v0.1-alpha/fltr-0.1-x86_64-${processor}.gz | gunzip > "$INSTALL_DIR/fltr"
curl -L https://huggingface.co/moritztng/Mixtral-8x7B-Instruct-v0.1/resolve/main/{weights.bin,tokenizer.json} -o "$INSTALL_DIR/weights.bin" -o "$INSTALL_DIR/tokenizer.json"

MODEL_URL=https://huggingface.co/moritztng/fltr/resolve/main
curl -L "$MODEL_URL/tokenizer.json" -o "$INSTALL_DIR/tokenizer.json"
if [[ ",$1," == *",small,"* ]]; then
curl -L "$MODEL_URL/mistral-7b-instruct-v0.2.bin" -o "$INSTALL_DIR/small.bin"
fi
if [[ ",$1," == *",large,"* ]]; then
curl -L "$MODEL_URL/mixtral-8x7b-instruct-v0.1.bin" -o "$INSTALL_DIR/large.bin"
fi
chmod +x "$INSTALL_DIR/fltr"

if [[ ":$PATH:" != *":$INSTALL_DIR:"* ]]; then
180 changes: 102 additions & 78 deletions src/lib.rs
@@ -77,7 +77,7 @@ struct Layer {
heads: QuantizedSlice<'static>,
rms_attention: &'static [f32],
rms_feedforward: &'static [f32],
gate: QuantizedSlice<'static>,
gate: Option<QuantizedSlice<'static>>,
experts: Vec<Expert>,
}

@@ -304,13 +304,13 @@ fn cache_kv(cache: &mut [f32], kv: &[f32], cache_lens: &[usize], kv_lens: &[usiz
}

impl Model {
pub fn from_dir(path: &Path) -> Model {
pub fn from_dir(path: &Path, multiple_experts: bool) -> Model {
#[cfg(feature = "cuda")]
unsafe {
cuda_init()
};
let mmap: MmapRaw = MmapOptions::new()
.map_raw_read_only(&File::open(path.join("weights.bin")).unwrap())
.map_raw_read_only(&File::open(path.join(format!("{}.bin", if multiple_experts {"large"} else {"small"}))).unwrap())
.unwrap();
let mut weights_ptr = mmap.as_ptr() as *const u8;
let rms_final = ptr_to_slice::<DIM>(&mut weights_ptr);
@@ -328,9 +328,15 @@ impl Model {
let value =
QuantizedSlice::from_ptr::<{ DIM * N_KV_HEADS * HEAD_SIZE }>(&mut weights_ptr);
let heads = QuantizedSlice::from_ptr::<{ DIM * DIM }>(&mut weights_ptr);
let gate = QuantizedSlice::from_ptr::<{ DIM * N_EXPERTS }>(&mut weights_ptr);
let gate = if multiple_experts {
Some(QuantizedSlice::from_ptr::<{ DIM * N_EXPERTS }>(
&mut weights_ptr,
))
} else {
None
};
let mut experts = Vec::new();
for _ in 0..N_EXPERTS {
for _ in 0..(if multiple_experts { N_EXPERTS } else { 1 }) {
experts.push(Expert {
ff1: QuantizedSlice::from_ptr::<{ DIM * HIDDEN_DIM }>(&mut weights_ptr),
ff2: QuantizedSlice::from_ptr::<{ HIDDEN_DIM * DIM }>(&mut weights_ptr),
@@ -537,91 +543,109 @@ impl Model {
&weights.rms_feedforward,
DIM,
);

quantize(&mut buffer.qstate, &buffer.state2);
matmul::<DIM, N_EXPERTS>(
&mut buffer.expert_logits,
&buffer.qstate.slice_full(),
&weights.gate,
);

let mut expert_tokens: [Vec<(usize, f32)>; 8] = Default::default();
for (p, token_expert_logits) in
buffer.expert_logits.chunks_exact_mut(N_EXPERTS).enumerate()
{
let mut indices_logits: Vec<_> = token_expert_logits.iter().enumerate().collect();
indices_logits
.sort_unstable_by(|(_, logit1), (_, logit2)| logit2.total_cmp(logit1));
let (expert_indices, mut expert_weights): (Vec<_>, Vec<_>) =
indices_logits.into_iter().take(N_EXPERTS_PER_TOKEN).unzip();
softmax(&mut expert_weights);
for (expert_index, expert_weight) in
expert_indices.iter().zip(expert_weights.iter())
if let Some(gate_weights) = &weights.gate {
matmul::<DIM, N_EXPERTS>(
&mut buffer.expert_logits,
&buffer.qstate.slice_full(),
gate_weights,
);
let mut expert_tokens: [Vec<(usize, f32)>; 8] = Default::default();
for (p, token_expert_logits) in
buffer.expert_logits.chunks_exact_mut(N_EXPERTS).enumerate()
{
expert_tokens[*expert_index].push((p, *expert_weight));
let mut indices_logits: Vec<_> =
token_expert_logits.iter().enumerate().collect();
indices_logits
.sort_unstable_by(|(_, logit1), (_, logit2)| logit2.total_cmp(logit1));
let (expert_indices, mut expert_weights): (Vec<_>, Vec<_>) =
indices_logits.into_iter().take(N_EXPERTS_PER_TOKEN).unzip();
softmax(&mut expert_weights);
for (expert_index, expert_weight) in
expert_indices.iter().zip(expert_weights.iter())
{
expert_tokens[*expert_index].push((p, *expert_weight));
}
}
}

for (expert_index, token_weights) in expert_tokens.iter().enumerate() {
if token_weights.is_empty() {
continue;
}
for (expert_index, token_weights) in expert_tokens.iter().enumerate() {
if token_weights.is_empty() {
continue;
}

let expert = &weights.experts[expert_index];
let n_tokens = token_weights.len();
let expert_qstate = buffer.qstate2.slice_mut(0, n_tokens * DIM);
for ((state_values, state_scales), (token_index, _)) in expert_qstate
.values
.chunks_exact_mut(DIM)
.zip(expert_qstate.scales.chunks_exact_mut(DIM / Q_GROUP_SIZE))
.zip(token_weights.iter())
{
state_values.copy_from_slice(
&buffer
.qstate
.values
.chunks_exact(DIM)
.nth(*token_index)
.unwrap(),
);
state_scales.copy_from_slice(
&buffer
.qstate
.scales
.chunks_exact(DIM / Q_GROUP_SIZE)
.nth(*token_index)
.unwrap(),
let expert = &weights.experts[expert_index];
let n_tokens = token_weights.len();
let expert_qstate = buffer.qstate2.slice_mut(0, n_tokens * DIM);
for ((state_values, state_scales), (token_index, _)) in expert_qstate
.values
.chunks_exact_mut(DIM)
.zip(expert_qstate.scales.chunks_exact_mut(DIM / Q_GROUP_SIZE))
.zip(token_weights.iter())
{
state_values.copy_from_slice(
&buffer
.qstate
.values
.chunks_exact(DIM)
.nth(*token_index)
.unwrap(),
);
state_scales.copy_from_slice(
&buffer
.qstate
.scales
.chunks_exact(DIM / Q_GROUP_SIZE)
.nth(*token_index)
.unwrap(),
);
}
let expert_qstate = buffer.qstate2.slice(0, n_tokens * DIM);
let expert_ff_hidden = &mut buffer.ff_hidden[..n_tokens * HIDDEN_DIM];
let expert_swiglu = &mut buffer.swiglu[..n_tokens * HIDDEN_DIM];
matmul::<DIM, HIDDEN_DIM>(expert_ff_hidden, &expert_qstate, &expert.ff1);
matmul::<DIM, HIDDEN_DIM>(expert_swiglu, &expert_qstate, &expert.swiglu);
for (hidden_x, swiglu_x) in
expert_ff_hidden.iter_mut().zip(expert_swiglu.iter())
{
*hidden_x *= 1f32 / (1f32 + (-*hidden_x).exp());
*hidden_x *= swiglu_x;
}
quantize(&mut buffer.qhidden, &expert_ff_hidden);
matmul::<HIDDEN_DIM, DIM>(
&mut buffer.state2[..n_tokens * DIM],
&buffer.qhidden.slice(0, n_tokens * HIDDEN_DIM),
&expert.ff2,
);
for (token_state, (token_index, weight)) in buffer.state2[..n_tokens * DIM]
.chunks_exact_mut(DIM)
.zip(token_weights.iter())
{
smul(token_state, *weight);
add(
&mut buffer
.state
.chunks_exact_mut(DIM)
.nth(*token_index)
.unwrap(),
token_state,
);
}
}
let expert_qstate = buffer.qstate2.slice(0, n_tokens * DIM);
let expert_ff_hidden = &mut buffer.ff_hidden[..n_tokens * HIDDEN_DIM];
let expert_swiglu = &mut buffer.swiglu[..n_tokens * HIDDEN_DIM];
matmul::<DIM, HIDDEN_DIM>(expert_ff_hidden, &expert_qstate, &expert.ff1);
matmul::<DIM, HIDDEN_DIM>(expert_swiglu, &expert_qstate, &expert.swiglu);
for (hidden_x, swiglu_x) in expert_ff_hidden.iter_mut().zip(expert_swiglu.iter()) {
} else {
matmul::<DIM, HIDDEN_DIM>(&mut buffer.ff_hidden, &buffer.qstate.slice_full(), &weights.experts[0].ff1);
matmul::<DIM, HIDDEN_DIM>(&mut buffer.swiglu, &buffer.qstate.slice_full(), &weights.experts[0].swiglu);
for (hidden_x, swiglu_x) in buffer.ff_hidden.iter_mut().zip(buffer.swiglu.iter()) {
*hidden_x *= 1f32 / (1f32 + (-*hidden_x).exp());
*hidden_x *= swiglu_x;
}
quantize(&mut buffer.qhidden, &expert_ff_hidden);
quantize(&mut buffer.qhidden, &buffer.ff_hidden);
matmul::<HIDDEN_DIM, DIM>(
&mut buffer.state2[..n_tokens * DIM],
&buffer.qhidden.slice(0, n_tokens * HIDDEN_DIM),
&expert.ff2,
&mut buffer.state2,
&buffer.qhidden.slice_full(),
&weights.experts[0].ff2,
);
for (token_state, (token_index, weight)) in buffer.state2[..n_tokens * DIM]
.chunks_exact_mut(DIM)
.zip(token_weights.iter())
{
smul(token_state, *weight);
add(
&mut buffer
.state
.chunks_exact_mut(DIM)
.nth(*token_index)
.unwrap(),
token_state,
);
}
add(&mut buffer.state, &buffer.state2);
}
}
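For reference, a rough Python sketch (not part of this commit) of the top-k expert routing that the MoE branch above implements: rank each token's expert logits, keep the top experts, apply softmax to the selected logits only, and group token indices by the expert that will process them. `N_EXPERTS = 8` matches the fixed-size `expert_tokens` array; `N_EXPERTS_PER_TOKEN = 2` is an assumption, not stated in this diff.

```python
import math

N_EXPERTS = 8            # matches the [Vec<(usize, f32)>; 8] array above
N_EXPERTS_PER_TOKEN = 2  # assumed value, not stated in this diff

def route(expert_logits):
    """expert_logits: one list of N_EXPERTS logits per token."""
    expert_tokens = [[] for _ in range(N_EXPERTS)]
    for p, logits in enumerate(expert_logits):
        # rank experts by logit and keep the top-k for this token
        top = sorted(range(N_EXPERTS), key=lambda i: logits[i], reverse=True)
        top = top[:N_EXPERTS_PER_TOKEN]
        # softmax over the selected logits only
        m = max(logits[i] for i in top)
        exps = [math.exp(logits[i] - m) for i in top]
        total = sum(exps)
        for i, e in zip(top, exps):
            expert_tokens[i].append((p, e / total))  # (token index, routing weight)
    return expert_tokens

# two example tokens; each expert then processes only the tokens routed to it
print(route([[0.1, 2.0, -1.0, 0.5, 0.0, 1.5, -0.2, 0.3],
             [1.0, 0.0, 0.0, 3.0, 0.2, 0.1, 0.0, 0.0]]))
```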
