From 83f74ddcd23938b37f7041a937a917a009f0f6bd Mon Sep 17 00:00:00 2001
From: lifegpc
Date: Mon, 30 Dec 2024 13:22:28 +0800
Subject: [PATCH] Add chat functionality with OpenAI API and configuration support

---
Reviewer note: the chat.py hunk below includes review fixes on top of the
original commit, so the blob id on its `index` line no longer matches.

 .gitignore |   2 +
 chat.py    | 286 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 288 insertions(+)
 create mode 100644 chat.py

diff --git a/.gitignore b/.gitignore
index e4c7c4b..5dabecb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -141,3 +141,5 @@ dmypy.json
 *.yml
 *.conf
 !*.example.yaml
+*.jsonl
+tiktokencache
diff --git a/chat.py b/chat.py
new file mode 100644
index 0000000..9ce0884
--- /dev/null
+++ b/chat.py
@@ -0,0 +1,286 @@
+"""Interactive command-line chat client for OpenAI-compatible APIs.
+
+Settings are resolved from command-line arguments first, then from a YAML
+configuration file, then from built-in defaults. The conversation can
+optionally be appended to a JSONL file when the session ends.
+"""
+import openai
+import yaml
+import argparse
+import asyncio
+import json
+from typing import Optional
+
+
+class Config:
+    """Resolve chat settings: CLI argument, then YAML value, then default."""
+
+    def __init__(self, args, yaml_config):
+        """Keep the parsed CLI namespace and the loaded YAML mapping.
+
+        Args:
+            args: Namespace produced by ``parser.parse_args()``.
+            yaml_config: Mapping loaded from the YAML file. Anything that
+                is not a dict (``None`` for an empty file, for example)
+                is treated as an empty configuration.
+        """
+        self._args = args
+        # Use the parameter rather than a module-level ``cfg`` global,
+        # which only exists when this file is run as a script.
+        self._cfg = yaml_config if isinstance(yaml_config, dict) else {}
+
+    def _text(self, name: str) -> Optional[str]:
+        """Return the CLI string for *name*, else the YAML one, else None."""
+        value = getattr(self._args, name, None)
+        if value:
+            return value
+        value = self._cfg.get(name)
+        return value if isinstance(value, str) else None
+
+    def _flag(self, name: str) -> bool:
+        """Return True if the CLI flag is set, else the YAML bool or False."""
+        if getattr(self._args, name, False):
+            return True
+        value = self._cfg.get(name)
+        return value if isinstance(value, bool) else False
+
+    def _number(self, name: str, low: float, high: float,
+                default: float) -> float:
+        """Return the first value of *name* that lies within [low, high].
+
+        The CLI value is tried before the YAML value. A candidate that is
+        missing, non-numeric or out of range is skipped; *default* is
+        returned when neither qualifies.
+        """
+        candidates = (getattr(self._args, name, None), self._cfg.get(name))
+        for value in candidates:
+            # bool is an int subclass but is never a valid setting here.
+            if isinstance(value, bool):
+                continue
+            # A type check instead of a truthiness test keeps an explicit
+            # 0 (e.g. ``-t 0``) from being mistaken for "not given", and
+            # lets YAML integers such as ``temperature: 1`` through.
+            if isinstance(value, (int, float)) and low <= value <= high:
+                return float(value)
+        return default
+
+    @property
+    def model(self) -> str:
+        """Name of the chat model to request."""
+        return self._text('model') or 'gpt-4o-mini'
+
+    @property
+    def max_completion_tokens(self) -> int:
+        """Upper bound on generated tokens (a positive integer)."""
+        candidates = (getattr(self._args, 'max_completion_tokens', None),
+                      self._cfg.get('max_completion_tokens'))
+        for value in candidates:
+            if isinstance(value, bool) or not isinstance(value, int):
+                continue
+            if value > 0:
+                return value
+        return 4096
+
+    @property
+    def include_usage(self) -> bool:
+        """Whether to request and print token usage for each reply."""
+        return self._flag('include_usage')
+
+    @property
+    def output(self) -> Optional[str]:
+        """Path of the JSONL transcript file, or None when unset."""
+        return self._text('output')
+
+    @property
+    def temperature(self) -> float:
+        """Sampling temperature in [0, 2]; defaults to 1.0."""
+        return self._number('temperature', 0.0, 2.0, 1.0)
+
+    @property
+    def top_p(self) -> float:
+        """Nucleus-sampling probability mass in [0, 1]; defaults to 1.0."""
+        return self._number('top_p', 0.0, 1.0, 1.0)
+
+    @property
+    def presence_penalty(self) -> float:
+        """Presence penalty in [-2, 2]; defaults to 0.0."""
+        return self._number('presence_penalty', -2.0, 2.0, 0.0)
+
+    @property
+    def store(self) -> bool:
+        """Value passed as the request's ``store`` option."""
+        return self._flag('store')
+
+
+def load_config(file_path):
+    """Load the YAML configuration and apply its API credentials.
+
+    ``api_key`` and ``base_url`` are copied onto the ``openai`` module when
+    present; absent keys leave the module's current values untouched.
+
+    Args:
+        file_path: Path of the YAML file to read (UTF-8).
+
+    Returns:
+        The loaded mapping, or an empty dict when the file does not
+        contain a mapping (for example when it is empty).
+    """
+    with open(file_path, 'r', encoding='utf-8') as file:
+        config = yaml.safe_load(file)
+    if not isinstance(config, dict):
+        config = {}
+    if 'api_key' in config:
+        openai.api_key = config['api_key']
+    if 'base_url' in config:
+        openai.base_url = config['base_url']
+    return config
+
+
+def get_user_prompt():
+    """Read a multi-line prompt from stdin; an empty line ends it.
+
+    Returns:
+        The entered lines joined with newlines, stripped of surrounding
+        whitespace.
+
+    Raises:
+        EOFError: If stdin is exhausted before any line was entered.
+    """
+    print("Enter your prompt (leave empty to finish):")
+    lines = []
+    while True:
+        try:
+            line = input('')
+        except EOFError:
+            # Keep what was typed before end-of-input instead of
+            # discarding it; only a completely empty read is an error.
+            if lines:
+                break
+            raise
+        if line.strip() == "":
+            break
+        lines.append(line)
+    return "\n".join(lines).strip()
+
+
+async def stream_response(messages, prompt, args: Config):
+    """Send *prompt* along with the history and stream the reply to stdout.
+
+    The user message is appended to *messages* in place. If the request
+    fails or is interrupted it is removed again, so a retry does not leave
+    a duplicated or unanswered user turn in the history.
+
+    Args:
+        messages: Conversation history (list of role/content dicts).
+        prompt: Text of the new user message.
+        args: Resolved settings for the request.
+
+    Returns:
+        The assistant message dict; the caller adds it to the history.
+    """
+    messages.append({"role": "user", "content": prompt})
+    parts = []
+    usage = None
+    try:
+        # The context manager closes the HTTP client after each request.
+        async with openai.AsyncClient(api_key=openai.api_key,
+                                      base_url=openai.base_url) as client:
+            response = await client.chat.completions.create(
+                model=args.model,
+                max_completion_tokens=args.max_completion_tokens,
+                messages=messages,
+                stream_options={"include_usage": args.include_usage},
+                temperature=args.temperature,
+                top_p=args.top_p,
+                presence_penalty=args.presence_penalty,
+                store=args.store,
+                stream=True
+            )
+            async for chunk in response:
+                if chunk.choices:
+                    delta = chunk.choices[0].delta
+                    if delta and delta.content:
+                        parts.append(delta.content)
+                        print(delta.content, end='', flush=True)
+                # Remember usage inside the loop so nothing refers to
+                # ``chunk`` after a stream that produced no chunks.
+                if chunk.usage:
+                    usage = chunk.usage
+    except BaseException:
+        # Also covers KeyboardInterrupt and task cancellation.
+        messages.pop()
+        raise
+    print(flush=True)
+    if usage:
+        print(f"Usage: {usage.to_json(indent=None)}")
+    return {'role': 'assistant', 'content': ''.join(parts)}
+
+
+async def chat(args: Config):
+    """Run the interactive loop until the user interrupts or stdin ends.
+
+    The transcript is written to ``args.output`` (when set and non-empty)
+    on the way out, including when an unexpected error ends the session.
+    """
+    messages = []
+    try:
+        while True:
+            try:
+                user_prompt = get_user_prompt()
+            except (KeyboardInterrupt, EOFError):
+                break
+            if not user_prompt:
+                # Nothing was typed; ask again instead of sending an
+                # empty message.
+                continue
+            while True:
+                try:
+                    res = await stream_response(messages, user_prompt, args)
+                except KeyboardInterrupt:
+                    break
+                except openai.InternalServerError:
+                    print("Internal server error, retrying...")
+                    continue
+                messages.append(res)
+                break
+    finally:
+        if args.output and messages:
+            save_to_jsonl(args.output, messages)
+
+
+def save_to_jsonl(file_path, message):
+    """Append *message* as one compact JSON line to *file_path*.
+
+    Args:
+        file_path: JSONL file to append to (created if missing).
+        message: List of message dicts, stored under the ``messages`` key.
+    """
+    with open(file_path, 'a', encoding='utf-8') as f:
+        json.dump({'messages': message}, f, ensure_ascii=False,
+                  separators=(',', ':'))
+        f.write('\n')
+
+
+parser = argparse.ArgumentParser(description="Chat with OpenAI's model.")
+parser.add_argument('-m', '--model', type=str, help='Chat model to use')  # noqa: E501
+parser.add_argument('-M', '--max-completion-tokens', type=int, help='Maximum length of the response')  # noqa: E501
+parser.add_argument('-c', '--config', type=str, default='./chat.yml', help='Path to the configuration file')  # noqa: E501
+parser.add_argument('-o', '--output', type=str, help='Path to the output JSONL file')  # noqa: E501
+parser.add_argument('-i', '--include-usage', action='store_true', help='Include usage information in the response')  # noqa: E501
+parser.add_argument('-t', '--temperature', type=float, help='What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. ')  # noqa: E501
+# argparse %-formats help strings, so a literal percent sign must be "%%".
+parser.add_argument('-p', '--top-p', type=float, help='An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10%% probability mass are considered.')  # noqa: E501
+parser.add_argument('-P', '--presence-penalty', type=float, help="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.")  # noqa: E501
+parser.add_argument('-s', '--store', action='store_true', help='Whether or not to store the output of this chat completion request for use in our model distillation or evals products.')  # noqa: E501
+
+
+if __name__ == "__main__":
+    args = parser.parse_args()
+    cfg = load_config(args.config)
+    acfg = Config(args, cfg)
+    try:
+        asyncio.run(chat(acfg))
+    except KeyboardInterrupt:
+        # Ctrl-C during a request: exit quietly; chat() has already
+        # saved the transcript.
+        pass