Commit
v3.0: Introduce long context model Chinese-LLaMA-2-7B/13B-16K (#187)
ymcui authored Aug 25, 2023
2 parents 0c9a5b3 + 96e4bfc commit b6ef97d
Showing 9 changed files with 785 additions and 114 deletions.
125 changes: 82 additions & 43 deletions README.md

Large diffs are not rendered by default.

124 changes: 80 additions & 44 deletions README_EN.md

Large diffs are not rendered by default.

8 changes: 7 additions & 1 deletion scripts/README.md
@@ -28,7 +28,13 @@ A server that implements OPENAI API using fastapi, Wiki: [https://github.com/ymc

C-Eval evaluation script, Wiki: [https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/ceval_zh](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/ceval_zh)

Inference script for C-Eval, Wiki: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/ceval_en

### cmmlu/

CMMLU evaluation script, Wiki: [https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/cmmlu_zh](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/cmmlu_zh)

Inference script for CMMLU, Wiki: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/cmmlu_en

### llama-cpp/

66 changes: 44 additions & 22 deletions scripts/attn_and_long_ctx_patches.py
@@ -18,6 +18,7 @@
USE_MEM_EFF_ATTENTION = False
ALPHA = 1.0
AUTO_COEFF = 1.0
SCALING_FACTOR = None


def apply_rotary_pos_emb_single(q, cos, sin, position_ids):
Expand Down Expand Up @@ -124,10 +125,25 @@ def xformers_forward(

old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__

def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None):
self.dim = dim

def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
t = t / self.scaling_factor

freqs = torch.einsum("i,j->ij", t, self.ntk_inv_freq.to(device))
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)


def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=None):
self.alpha = ALPHA
self.max_position_embeddings = max_position_embeddings
if SCALING_FACTOR is None:
self.scaling_factor = scaling_factor or 1.0
else:
self.scaling_factor = SCALING_FACTOR
if isinstance(ALPHA,(float,int)):
base = base * ALPHA ** (dim / (dim-2))
self.base = base
@@ -136,24 +152,21 @@ def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, devic
else:
raise ValueError(ALPHA)
old_init(self, dim, max_position_embeddings, base, device)
ntk_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
self.register_buffer("ntk_inv_freq", ntk_inv_freq, persistent=False)
self.ntk_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))

self._set_cos_sin_cache = _set_cos_sin_cache
self._set_cos_sin_cache(
self, seq_len=max_position_embeddings, device=self.ntk_inv_freq.device, dtype=torch.get_default_dtype()
)


def adaptive_ntk_forward(self, x, seq_len=None):
if seq_len > self.max_seq_len_cached:
if isinstance(self.alpha,(float,int)):
self.max_seq_len_cached = seq_len
t = torch.arange(seq_len, device=x.device, dtype=self.ntk_inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.ntk_inv_freq)
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)
self._set_cos_sin_cache(self, seq_len=seq_len, device=x.device, dtype=x.dtype)
elif self.alpha=='auto':
t = torch.arange(seq_len, device=x.device, dtype=self.ntk_inv_freq.dtype)
t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
t = t / self.scaling_factor
dim = self.dim
alpha = (seq_len / (self.max_position_embeddings/2) - 1) * AUTO_COEFF
base = self.base * alpha ** (dim / (dim-2))
@@ -167,11 +180,10 @@ def adaptive_ntk_forward(self, x, seq_len=None):
cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
)
else:
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
)
return (
self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
)


def apply_attention_patch(
@@ -187,15 +199,25 @@ def apply_attention_patch(
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward


def apply_ntk_scaling_patch(alpha: Union[float,str]):
def apply_ntk_scaling_patch(alpha: Union[float,str], scaling_factor: Optional[float] = None):
global ALPHA
global SCALING_FACTOR
ALPHA = alpha
SCALING_FACTOR = scaling_factor
try:
ALPHA = float(ALPHA)
except ValueError:
if ALPHA!="auto":
raise ValueError(f"Alpha can only be a float or 'auto', but given {ALPHA}")
print(f"Apply NTK scaling with ALPHA={ALPHA}")
if scaling_factor is None:
print(f"The value of scaling factor will be read from model config file, or set to 1.")
else:
print(f"Warning: scaling factor is set to {SCALING_FACTOR}. \
If you set the value by hand, do not forget to update \
max_position_embeddings in the model config file.")

transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init
if hasattr(transformers.models.llama.modeling_llama,'LlamaLinearScalingRotaryEmbedding'):
transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__ = adaptive_ntk_init
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward
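
For context, a minimal usage sketch of the patched rotary embedding (not part of this commit): `apply_ntk_scaling_patch` has to run before the model is built so that `LlamaRotaryEmbedding` instances are constructed from the patched `__init__`/`forward`. The import path and the checkpoint path below are illustrative assumptions, not values taken from the repository.

```python
# Minimal sketch, assuming it runs next to scripts/attn_and_long_ctx_patches.py;
# the checkpoint path is a placeholder, not a value from this commit.
from transformers import LlamaForCausalLM, LlamaTokenizer

from attn_and_long_ctx_patches import apply_ntk_scaling_patch

# alpha: a float for fixed NTK scaling, or "auto" to derive alpha from the
# sequence length at forward time. scaling_factor is optional; when omitted,
# it is read from the model config file or defaults to 1.0.
apply_ntk_scaling_patch(alpha="auto", scaling_factor=None)

# Patch first, then load: models created afterwards pick up the patched
# rotary embedding classes.
tokenizer = LlamaTokenizer.from_pretrained("path/to/chinese-llama-2-7b-16k")
model = LlamaForCausalLM.from_pretrained("path/to/chinese-llama-2-7b-16k")
```

With a float alpha, the patched `__init__` rescales the rotary base as `base * alpha ** (dim / (dim-2))`; with `alpha='auto'`, the patched `forward` derives `alpha = (seq_len / (max_position_embeddings / 2) - 1) * AUTO_COEFF` on the fly, e.g. alpha = 3 for `max_position_embeddings=4096` and `seq_len=8192`.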
148 changes: 148 additions & 0 deletions scripts/cmmlu/categories.py
@@ -0,0 +1,148 @@
# This code is modified from CMMLU Project: https://github.com/haonan-li/CMMLU
name_en2zh = {
"agronomy": "农学",
"anatomy": "解剖学",
"ancient_chinese": "古汉语",
"arts": "艺术学",
"astronomy": "天文学",
"business_ethics": "商业伦理",
"chinese_civil_service_exam": "中国公务员考试",
"chinese_driving_rule": "中国驾驶规则",
"chinese_food_culture": "中国饮食文化",
"chinese_foreign_policy": "中国外交政策",
"chinese_history":"中国历史",
"chinese_literature": "中国文学",
"chinese_teacher_qualification": "中国教师资格",
"clinical_knowledge": "临床知识",
"college_actuarial_science":"大学精算学",
"college_education":"大学教育学",
"college_engineering_hydrology": "大学工程水文学",
"college_law": "大学法律",
"college_mathematics": "大学数学",
"college_medical_statistics":"大学医学统计",
"college_medicine": "大学医学",
"computer_science": "计算机科学",
"computer_security": "计算机安全",
"conceptual_physics": "概念物理学",
"construction_project_management": "建设工程管理",
"economics": "经济学",
"education": "教育学",
"electrical_engineering": "电气工程",
"elementary_chinese":"小学语文",
"elementary_commonsense":"小学常识",
"elementary_information_and_technology": "小学信息技术",
"elementary_mathematics": "初等数学",
"ethnology": "民族学",
"food_science": "食品科学",
"genetics": "遗传学",
"global_facts": "全球事实",
"high_school_biology": "高中生物",
"high_school_chemistry": "高中化学",
"high_school_geography": "高中地理",
"high_school_mathematics": "高中数学",
"high_school_physics": "高中物理学",
"high_school_politics": "高中政治",
"human_sexuality": "人类性行为",
"international_law": "国际法学",
"journalism": "新闻学",
"jurisprudence": "法理学",
"legal_and_moral_basis": "法律与道德基础",
"logical": "逻辑学",
"machine_learning": "机器学习",
"management": "管理学",
"marketing": "市场营销",
"marxist_theory": "马克思主义理论",
"modern_chinese": "现代汉语",
"nutrition": "营养学",
"philosophy": "哲学",
"professional_accounting": "专业会计",
"professional_law": "专业法学",
"professional_medicine": "专业医学",
"professional_psychology": "专业心理学",
"public_relations": "公共关系",
"security_study":"安全研究",
"sociology": "社会学",
"sports_science": "体育学",
"traditional_chinese_medicine": "中医中药",
"virology": "病毒学",
"world_history":"世界历史",
"world_religions": "世界宗教",
}

subcategories = {
"agronomy": ['other'],
"anatomy": ['biology'],
"ancient_chinese": ['linguistics','china specific'],
"arts": ['arts'],
"astronomy": ['physics'],
"business_ethics": ['business'],
"chinese_civil_service_exam": ['politics','china specific'],
"chinese_driving_rule": ['other','china specific'],
"chinese_food_culture": ['culture','china specific'],
"chinese_foreign_policy": ['politics','china specific'],
"chinese_history":['history','china specific'],
"chinese_literature": ['literature','china specific'],
"chinese_teacher_qualification": ['education','china specific'],
"college_actuarial_science":['math'],
"college_education":['education'],
"college_engineering_hydrology": ['engineering'],
"college_law": ['law'],
"college_mathematics": ['math'],
"college_medical_statistics":['statistics'],
"clinical_knowledge": ['other'],
"college_medicine": ['other'],
"computer_science": ['computer science'],
"computer_security": ['other'],
"conceptual_physics": ['physics'],
"construction_project_management": ['other','china specific'],
"economics": ['economics'],
"education": ['education'],
"elementary_chinese":['linguistics','china specific'],
"elementary_commonsense":['other','china specific'],
"elementary_information_and_technology": ['other'],
"electrical_engineering": ['engineering'],
"elementary_mathematics": ['math'],
"ethnology": ['culture','china specific'],
"food_science": ['other'],
"genetics": ['biology'],
"global_facts": ['global'],
"high_school_biology": ['biology'],
"high_school_chemistry": ['chemistry'],
"high_school_geography": ['geography'],
"high_school_mathematics": ['math'],
"high_school_physics": ['physics'],
"high_school_politics": ['politics','china specific'],
"human_sexuality": ['other'],
"international_law": ['law'],
"journalism": ['sociology'],
"jurisprudence": ['law'],
"legal_and_moral_basis": ['other'],
"logical": ['philosophy'],
"machine_learning": ['computer science'],
"management": ['business'],
"marketing": ['business'],
"marxist_theory": ['philosophy'],
"modern_chinese": ['linguistics','china specific'],
"nutrition": ['other'],
"philosophy": ['philosophy'],
"professional_accounting": ['business'],
"professional_law": ['law'],
"professional_medicine": ['other'],
"professional_psychology": ['psychology'],
"public_relations": ['politics'],
"security_study": ['politics'],
"sociology": ['culture'],
"sports_science": ['other'],
"traditional_chinese_medicine": ['other','china specific'],
"virology": ['biology'],
"world_history":['history'],
"world_religions": ['global'],
}

categories = {
"STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering", "statistics"],
"Humanities": ["history", "philosophy", "law", "arts", "literature", "global"],
"Social Science": ['linguistics',"business", "politics", "culture", "economics", "geography", "psychology", "education", "sociology"],
"Other":["other"],
"China specific": ["china specific"],
}
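
For reference, a minimal sketch (not part of this commit) of how these three tables compose: each subject maps to subcategory tags via `subcategories`, and the tags roll up into the top-level groups in `categories`. The helper name `subject_to_groups` and the direct `from categories import ...` are assumptions for illustration.

```python
# Minimal sketch, assuming it is run from scripts/cmmlu/ so that
# categories.py is importable directly; subject_to_groups is a
# hypothetical helper, not part of the repository.
from categories import name_en2zh, subcategories, categories

def subject_to_groups(subject: str) -> list:
    """Return the top-level groups (STEM, Humanities, ...) a subject rolls up into."""
    tags = subcategories[subject]
    return [group for group, members in categories.items()
            if any(tag in members for tag in tags)]

print(name_en2zh["astronomy"])               # 天文学
print(subject_to_groups("astronomy"))        # ['STEM']
print(subject_to_groups("ancient_chinese"))  # ['Social Science', 'China specific']
```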