Add chat functionality with OpenAI API and configuration support

This commit is contained in:
2024-12-30 13:22:28 +08:00
parent a6cc95b2ed
commit 83f74ddcd2
2 changed files with 182 additions and 0 deletions

2
.gitignore vendored
View File

@@ -141,3 +141,5 @@ dmypy.json
*.yml *.yml
*.conf *.conf
!*.example.yaml !*.example.yaml
*.jsonl
tiktokencache

180
chat.py Normal file
View File

@@ -0,0 +1,180 @@
import openai
import yaml
import argparse
import asyncio
import json
from typing import Optional
class Config:
def __init__(self, args, yaml_config):
self._args = args
self._cfg = cfg
@property
def model(self) -> str:
if self._args.model:
return self._args.model
if 'model' in self._cfg and isinstance(self._cfg['model'], str):
return self._cfg['model']
return 'gpt-4o-mini'
@property
def max_completion_tokens(self) -> int:
if self._args.max_completion_tokens:
return self._args.max_completion_tokens
if 'max_completion_tokens' in self._cfg and isinstance(self._cfg['max_completion_tokens'], int): # noqa: E501
return self._cfg['max_completion_tokens']
return 4096
@property
def include_usage(self) -> bool:
if self._args.include_usage:
return self._args.include_usage
if 'include_usage' in self._cfg and isinstance(self._cfg['include_usage'], bool): # noqa: E501
return self._cfg['include_usage']
return False
@property
def output(self) -> Optional[str]:
if self._args.output:
return self._args.output
if 'output' in self._cfg and isinstance(self._cfg['output'], str):
return self._cfg['output']
return None
@property
def temperature(self) -> float:
if self._args.temperature:
if self._args.temperature >= 0.0 and self._args.temperature <= 2.0:
return self._args.temperature
if 'temperature' in self._cfg and isinstance(self._cfg['temperature'], float): # noqa: E501
temperature = self._cfg['temperature']
if temperature >= 0.0 and temperature <= 2.0:
return temperature
return 1.0
@property
def top_p(self) -> float:
if self._args.top_p:
if self._args.top_p >= 0.0 and self._args.top_p <= 1.0:
return self._args.top_p
if 'top_p' in self._cfg and isinstance(self._cfg['top_p'], float):
top_p = self._cfg['top_p']
if top_p >= 0.0 and top_p <= 1.0:
return top_p
return 1.0
@property
def presence_penalty(self) -> float:
if self._args.presence_penalty:
if self._args.presence_penalty >= -2.0 and self._args.presence_penalty <= 2.0: # noqa: E501
return self._args.presence_penalty
if 'presence_penalty' in self._cfg and isinstance(self._cfg['presence_penalty'], float): # noqa: E501
presence_penalty = self._cfg['presence_penalty']
if presence_penalty >= -2.0 and presence_penalty <= 2.0:
return presence_penalty
return 0.0
@property
def store(self) -> bool:
if self._args.store:
return self._args.store
if 'store' in self._cfg and isinstance(self._cfg['store'], bool):
return self._cfg['store']
return False
def load_config(file_path):
with open(file_path, 'r', encoding='utf-8') as file:
config = yaml.safe_load(file)
openai.api_key = config['api_key']
openai.base_url = config['base_url']
return config
def get_user_prompt():
prompt = ""
print("Enter your prompt (leave empty to finish):")
while True:
line = input('')
if line.strip() == "":
break
prompt += line + "\n"
return prompt.strip()
async def stream_response(messages, prompt, args: Config):
client = openai.AsyncClient(api_key=openai.api_key,
base_url=openai.base_url)
messages.append({"role": "user", "content": prompt})
response = await client.chat.completions.create(
model=args.model,
max_completion_tokens=args.max_completion_tokens,
messages=messages,
stream_options={"include_usage": args.include_usage},
temperature=args.temperature,
top_p=args.top_p,
presence_penalty=args.presence_penalty,
store=args.store,
stream=True
)
res = ''
async for chunk in response:
if chunk.choices:
choice = chunk.choices[0]
if choice.delta and choice.delta.content:
data = choice.delta.content
res += data
print(data, end='', flush=True)
print(flush=True)
if chunk.usage:
print(f"Usage: {chunk.usage.to_json(indent=None)}")
return {'role': 'assistant', 'content': res}
async def chat(args: Config):
messages = []
while True:
try:
user_prompt = get_user_prompt()
except KeyboardInterrupt:
break
while True:
try:
res = await stream_response(messages, user_prompt, args)
messages.append(res)
break
except KeyboardInterrupt:
break
except openai.InternalServerError:
print("Internal server error, retrying...")
continue
if args.output and len(messages):
save_to_jsonl(args.output, messages)
def save_to_jsonl(file_path, message):
with open(file_path, 'a', encoding='utf-8') as f:
json.dump({'messages': message}, f, ensure_ascii=False,
separators=(',', ':'))
f.write('\n')
parser = argparse.ArgumentParser(description="Chat with OpenAI's model.")
parser.add_argument('-m', '--model', type=str, help='Chat model to use') # noqa: E501
parser.add_argument('-M', '--max-completion-tokens', type=int, help='Maximum length of the response') # noqa: E501
parser.add_argument('-c', '--config', type=str, default='./chat.yml', help='Path to the configuration file') # noqa: E501
parser.add_argument('-o', '--output', type=str, help='Path to the output JSONL file') # noqa: E501
parser.add_argument('-i', '--include-usage', action='store_true', help='Include usage information in the response') # noqa: E501
parser.add_argument('-t', '--temperature', type=float, help='What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. ') # noqa: E501
parser.add_argument('-p', '--top-p', type=float, help='An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.') # noqa: E501
parser.add_argument('-P', '--presence-penalty', type=float, help="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.") # noqa: E501
parser.add_argument('-s', '--store', action='store_true', help='Whether or not to store the output of this chat completion request for use in our model distillation or evals products.') # noqa: E501
if __name__ == "__main__":
args = parser.parse_args()
cfg = load_config(args.config)
acfg = Config(args, cfg)
asyncio.run(chat(acfg))