diff --git a/gpt_shows_rename/__main__.py b/gpt_shows_rename/__main__.py index 4ef5e9c..e212d46 100644 --- a/gpt_shows_rename/__main__.py +++ b/gpt_shows_rename/__main__.py @@ -14,8 +14,8 @@ async def main(): print('WARN: TMDB API key is not set, skip TMDB data integration') else: tmdb = TmdbClient(cfg) - tmdb_data = await tmdb.get_tmdb_data(cfg.tmdb_id) - res = await get_response(cfg, cfg.input, files, cfg.series_name, cfg.year, cfg.tmdb_id, cfg.tvdb_id, tmdb_data) + tmdb_data = await tmdb.get_tmdb_data(cfg.tmdb_id, cfg.season_number) + res = await get_response(cfg, cfg.input, files, cfg.series_name, cfg.year, cfg.tmdb_id, cfg.tvdb_id, tmdb_data, cfg.season_number) for f in res.files: print(files[f.index], '->', f.name) input('Continue?') diff --git a/gpt_shows_rename/config.py b/gpt_shows_rename/config.py index 2b4e33f..55d1e37 100644 --- a/gpt_shows_rename/config.py +++ b/gpt_shows_rename/config.py @@ -40,6 +40,10 @@ class Config: def proxy(self) -> Optional[str]: return self._args.proxy if self._args.proxy is not None else self._yaml_config.get('proxy') + @property + def season_number(self) -> Optional[int]: + return self._args.season_number + @property def series_name(self) -> Optional[str]: return self._args.series_name @@ -86,6 +90,7 @@ def get_arg_parser() -> argparse.ArgumentParser: parser.add_argument('-T', '--tvdb-id', type=int, help='TVDB ID (optional)') parser.add_argument('-H', '--hardlink', action='store_true', help='Use hardlink instead of symlink.') parser.add_argument('-n', '--no-tmdb', action='store_true', help='Do not use TMDB API to obtain data.') + parser.add_argument('-S', '--season-number', type=int, help='Season number (optional)') parser.add_argument('input', help='Input directory.') parser.add_argument('output', help='Output directory.') return parser diff --git a/gpt_shows_rename/gpt.py b/gpt_shows_rename/gpt.py index 1b6429c..d4bca9b 100644 --- a/gpt_shows_rename/gpt.py +++ b/gpt_shows_rename/gpt.py @@ -83,7 +83,7 @@ def parse_result(result: str): async def get_response(cfg: Config, inp: str, files: List[str], series_name: str = None, year: int = None, tmdb_id: int = None, - tvdb_id: int = None, tmdb_data: TmdbData = None) -> Files: + tvdb_id: int = None, tmdb_data: TmdbData = None, season_number: int = None) -> Files: prompt = f'The input directory is `{inp}`.' if tmdb_data: prompt += '\n' + gen_tmdb_data(tmdb_data) @@ -96,6 +96,8 @@ async def get_response(cfg: Config, inp: str, files: List[str], prompt += f'\nThe TMDB ID is `{tmdb_id}`.' if tvdb_id: prompt += f'\nThe TVDB ID is `{tvdb_id}`.' + if season_number: + prompt += f'\nThe season number is `{season_number}`.' http_client = httpx.AsyncClient(proxy=cfg.proxy) client = openai.AsyncClient(api_key=cfg.api_key, base_url=cfg.base_url, http_client=http_client) if is_support_structed_output(cfg.model):