Skip to content

Commit

Permalink
Merge pull request #22 from vtuber-plan/function
Browse files Browse the repository at this point in the history
support gptq quantization and update transformers version
  • Loading branch information
FrostMiKu authored Sep 15, 2023
2 parents 63a2a34 + e858a57 commit a19cd7f
Show file tree
Hide file tree
Showing 17 changed files with 484 additions and 81 deletions.
54 changes: 26 additions & 28 deletions langport/data/conversation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class SeparatorStyle(Enum):
ADD_COLON_TWO = auto()
ADD_COLON_SPACE_SINGLE = auto()
NO_COLON_SINGLE = auto()
NO_COLON_TWO = auto()
ADD_NEW_LINE_SINGLE = auto()
DOLLY = auto()
RWKV = auto()
Expand All @@ -34,8 +35,9 @@ class ConversationSettings:
sep_style: SeparatorStyle
sep: str
sep2: Optional[str] = None
system_sep: Optional[str] = None
round_sep: Optional[str] = None
# The template of the system prompt
system_template: str = "{system_message}"
# Stop criterion (the default is the EOS token)
stop_str: str = None
# Stop generation upon encountering any token in this list
Expand All @@ -45,10 +47,10 @@ def copy(self):
return ConversationSettings(
name=self.name,
roles=self.roles,
system_template=self.system_template,
sep_style=self.sep_style,
sep=self.sep,
sep2=self.sep2,
system_sep=self.system_sep,
round_sep=self.round_sep,
stop_str=self.stop_str,
stop_token_ids=self.stop_token_ids,
Expand All @@ -74,12 +76,9 @@ def get_prompt(self) -> str:
round_sep = self.settings.round_sep
else:
round_sep = ""
system_prompt = self.settings.system_template.format(system_message=self.system)
if self.settings.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
if self.settings.system_sep is not None:
ret = self.system + self.settings.system_sep
else:
ret = self.system + self.settings.sep

ret = system_prompt + self.settings.sep
for i, (role, message) in enumerate(self.messages):
if i % len(self.settings.roles) == 0:
ret += round_sep
Expand All @@ -90,10 +89,7 @@ def get_prompt(self) -> str:
return ret
elif self.settings.sep_style == SeparatorStyle.ADD_COLON_TWO:
seps = [self.settings.sep, self.settings.sep2]
if self.settings.system_sep is not None:
ret = self.system + self.settings.system_sep
else:
ret = self.system + seps[0]
ret = system_prompt + seps[0]

for i, (role, message) in enumerate(self.messages):
if i % len(self.settings.roles) == 0:
Expand All @@ -104,10 +100,7 @@ def get_prompt(self) -> str:
ret += role + ": "
return ret
elif self.settings.sep_style == SeparatorStyle.NO_COLON_SINGLE:
if self.settings.system_sep is not None:
ret = self.system + self.settings.system_sep
else:
ret = self.system
ret = system_prompt + self.settings.sep

for i, (role, message) in enumerate(self.messages):
if i % len(self.settings.roles) == 0:
Expand All @@ -117,11 +110,17 @@ def get_prompt(self) -> str:
else:
ret += role
return ret
elif self.settings.sep_style == SeparatorStyle.NO_COLON_TWO:
seps = [self.settings.sep, self.settings.sep2]
ret = system_prompt + seps[0]
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + message + seps[i % 2]
else:
ret += role
return ret
elif self.settings.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
if self.settings.system_sep is not None:
ret = self.system + self.settings.system_sep
else:
ret = self.system + self.settings.sep
ret = "" if system_prompt == "" else system_prompt + self.settings.sep

for i, (role, message) in enumerate(self.messages):
if i % len(self.settings.roles) == 0:
Expand All @@ -134,7 +133,7 @@ def get_prompt(self) -> str:
return ret
elif self.settings.sep_style == SeparatorStyle.DOLLY:
seps = [self.settings.sep, self.settings.sep2]
ret = self.system
ret = system_prompt
for i, (role, message) in enumerate(self.messages):
if message:
ret += role + ":\n" + message + seps[i % 2]
Expand All @@ -144,7 +143,7 @@ def get_prompt(self) -> str:
ret += role + ":\n"
return ret
elif self.settings.sep_style == SeparatorStyle.RWKV:
ret = self.system
ret = system_prompt + self.settings.sep
for i, (role, message) in enumerate(self.messages):
if message:
ret += (
Expand All @@ -157,15 +156,18 @@ def get_prompt(self) -> str:
ret += role + ":"
return ret
elif self.settings.sep_style == SeparatorStyle.PHOENIX:
ret = self.system
ret = system_prompt
for role, message in self.messages:
if message:
ret += role + ": " + "<s>" + message + "</s>"
else:
ret += role + ": " + "<s>"
return ret
elif self.settings.sep_style == SeparatorStyle.CHATGLM:
ret = self.system
if system_prompt:
ret = system_prompt + self.settings.sep
else:
ret = ""
for i, (role, message) in enumerate(self.messages):
if message:
if i % 2 == 0:
Expand Down Expand Up @@ -193,11 +195,7 @@ def get_prompt(self) -> str:
return ret
elif self.settings.sep_style == SeparatorStyle.CHATLM:
im_start, im_end = "<|im_start|>", "<|im_end|>"
ret = im_start + "system" + "\n" + self.system + im_end
if self.settings.system_sep is not None:
ret += self.settings.system_sep
else:
ret += self.settings.sep
ret = system_prompt + self.settings.sep

for i, (role, message) in enumerate(self.messages):
ret += im_start
Expand Down
17 changes: 17 additions & 0 deletions langport/data/conversation/settings/baichuan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from langport.data.conversation import (
ConversationSettings,
SeparatorStyle,
)

"""
https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/modeling_baichuan.py#L555
"""
# Default conversation template for Baichuan chat models.
# With the NO_COLON_TWO style, each role marker is concatenated directly with
# its message (no colon), and the two separators alternate per message.
baichuan = ConversationSettings(
    name="baichuan",
    sep_style=SeparatorStyle.NO_COLON_TWO,
    # Role markers, in turn order; the surrounding spaces are part of the marker.
    roles=(" <reserved_102> ", " <reserved_103> "),
    # `sep` follows the system prompt and even-indexed messages;
    # `sep2` follows odd-indexed messages.
    sep="",
    sep2="</s>",
    # NOTE(review): presumably the EOS id and a reserved role-token id in the
    # Baichuan vocabulary - confirm against the tokenizer.
    stop_token_ids=[2, 195],
)
4 changes: 2 additions & 2 deletions langport/data/conversation/settings/internlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
roles=("<|User|>", "<|Bot|>"),
sep_style=SeparatorStyle.ADD_COLON_TWO,
round_sep="<s>",
system_sep="",
sep="<eoh>\n",
sep2="<eoa>\n",
stop_str="<eoa>"
stop_str="<eoa>",
stop_token_ids=[1, 2]
)
2 changes: 1 addition & 1 deletion langport/data/conversation/settings/openbuddy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
name="openbuddy",
roles=("User", "Assistant"),
sep_style=SeparatorStyle.ADD_COLON_SINGLE,
system_sep="\n\n",
system_template="{system_message}\n",
sep="\n",
stop_str="\n</s>",
)
9 changes: 6 additions & 3 deletions langport/data/conversation/settings/qwen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
SeparatorStyle,
)


one_shot = ConversationSettings(
"""
https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/qwen_generation_utils.py#L119
"""
qwen = ConversationSettings(
name="qwen",
roles=("user", "assistant"),
sep_style=SeparatorStyle.CHATLM,
system_template="<|im_start|>system\n{system_message}<|im_end|>",
sep="\n",
stop_str="<|im_end|>",
)
)
6 changes: 3 additions & 3 deletions langport/model/adapters/baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ class BaichuanAdapter(BaseAdapter):
"""The model adapter for baichuan-inc/baichuan-7B"""

def match(self, model_path: str):
return "baichuan" in model_path
return "baichuan" in model_path.lower()

def get_default_conv_template(self, model_path: str) -> ConversationHistory:
settings = get_conv_settings("one_shot")
settings = get_conv_settings("baichuan")
return ConversationHistory(
system="",
messages=(),
messages=[],
offset=0,
settings=settings,
)
Loading

0 comments on commit a19cd7f

Please sign in to comment.