diff --git a/aishell/cli.py b/aishell/cli.py
index 5a45566..1966acd 100644
--- a/aishell/cli.py
+++ b/aishell/cli.py
@@ -1,6 +1,9 @@
 import os
+import sys
 import webbrowser
 
+import pyperclip
+import rich
 import typer
 from rich.console import Console
 from yt_dlp.cookies import SUPPORTED_BROWSERS
@@ -18,6 +21,15 @@ def _open_chatgpt_browser():
     webbrowser.open(CHATGPT_LOGIN_URL)
 
 
+def _ask_user_copy_session_token_to_clipboard(session_token: str) -> None:
+    copy_session_token = typer.confirm('Do you want to copy the session token to your clipboard?')
+    if copy_session_token:
+        pyperclip.copy(session_token)
+        rich.print(
+            'Session token copied to clipboard. [bold]`export CHATGPT_SESSION_TOKEN=`[/bold] to set it.'
+        )
+
+
 @cli_app.command()
 def ask(question: str, language_model: LanguageModel = LanguageModel.REVERSE_ENGINEERED_CHATGPT):
     query_client: QueryClient
@@ -34,7 +46,13 @@ def ask(question: str, language_model: LanguageModel = LanguageModel.REVERSE_ENG
             BROWSER_NAME = typer.prompt(f'Which browser did you use to log in? [{SUPPORTED_BROWSERS}]')
             adapter = OpenAICookieAdapter(BROWSER_NAME)
             session_token = adapter.get_openai_session_token()
-            query_client = ReverseEngineeredChatGPTClient(session_token=session_token)
+            if session_token is not None:
+                os.environ['CHATGPT_SESSION_TOKEN'] = session_token
+                _ask_user_copy_session_token_to_clipboard(session_token)
+                ask(question, language_model)
+            else:
+                print('Failed to log in.')
+                sys.exit()
 
     query_client.query(question)
@@ -43,9 +61,9 @@ def ask(question: str, language_model: LanguageModel = LanguageModel.REVERSE_ENG
         f'''
 [green] AiShell is thinking of `{question}` ...[/green]
-[italic]AiShell is not responsible for any damage caused by the command executed by the user.[/italic]'''.strip(), ):
+[dim]AiShell is not responsible for any damage caused by the command executed by the user.[/dim]'''.strip(), ):
         response = query_client.query(question)
-        console.print(f'[italic]ai$hell: {response}\n')
+        console.print(f'AiShell: {response}\n')
 
     will_execute = typer.confirm('Execute this command?')
diff --git a/aishell/models/__init__.py b/aishell/models/__init__.py
index 373d3c6..fb49e07 100644
--- a/aishell/models/__init__.py
+++ b/aishell/models/__init__.py
@@ -1,2 +1,3 @@
 from .language_model import LanguageModel as LanguageModel
 from .open_ai_response_model import OpenAIResponseModel as OpenAIResponseModel
+from .revchatgpt_chatbot_config_model import RevChatGPTChatbotConfigModel as RevChatGPTChatbotConfigModel
diff --git a/aishell/models/open_ai_response_model.py b/aishell/models/open_ai_response_model.py
index 79cdf4f..5937ec6 100644
--- a/aishell/models/open_ai_response_model.py
+++ b/aishell/models/open_ai_response_model.py
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
 
 from pydantic import BaseModel
 
@@ -16,9 +16,12 @@ class Usage(BaseModel):
         prompt_tokens: int
         total_tokens: int
 
-    choices: Optional[List[Choice]]
+    choices: Optional[list[Choice]]
     created: int
     id: str
     model: str
     object: str
     usage: Usage
+
+    class Config:
+        frozen = True
diff --git a/aishell/models/revchatgpt_chatbot_config_model.py b/aishell/models/revchatgpt_chatbot_config_model.py
new file mode 100644
index 0000000..bc8e7c9
--- /dev/null
+++ b/aishell/models/revchatgpt_chatbot_config_model.py
@@ -0,0 +1,23 @@
+from typing import Optional
+
+from pydantic import BaseModel, root_validator
+
+
+class RevChatGPTChatbotConfigModel(BaseModel):
+    email: Optional[str] = None
+    password: Optional[str] = None
+    session_token: Optional[str] = None
+    access_token: Optional[str] = None
+    paid: bool = False
+
+    @root_validator
+    def check_at_least_one_account_info(cls, values: dict[str, Optional[str]]):
+        IS_ACCOUNT_LOGIN = values.get('email') and values.get('password')
+        IS_TOKEN_AUTH = values.get('session_token') or values.get('access_token')
+        if not IS_ACCOUNT_LOGIN and not IS_TOKEN_AUTH:
+            raise ValueError('No information for authentication provided.')
+
+        return values
+
+    class Config:
+        frozen = True
diff --git a/aishell/query_clients/gpt3_client.py b/aishell/query_clients/gpt3_client.py
index 7a707ec..1a294e0 100644
--- a/aishell/query_clients/gpt3_client.py
+++ b/aishell/query_clients/gpt3_client.py
@@ -1,5 +1,5 @@
 import os
-from typing import cast
+from typing import Final, cast
 
 import openai
 
@@ -11,16 +11,9 @@ class GPT3Client(QueryClient):
-    def _construct_prompt(self, text: str) -> str:
-        return f'''User: You are now a translater from human language to {os.uname()[0]} shell command.
-        No explanation required, respond with only the raw shell command.
-        What should I type to shell for: {text}, in one line.
-
-        You: '''
-
     def query(self, prompt: str) -> str:
         prompt = self._construct_prompt(prompt)
-        completion: OpenAIResponseModel = cast(  # type: ignore [no-any-unimported]
+        completion: Final[OpenAIResponseModel] = cast(
            OpenAIResponseModel,
            openai.Completion.create(
                engine='text-davinci-003',
@@ -32,5 +25,12 @@ def query(self, prompt: str) -> str:
         if not completion.choices or len(completion.choices) == 0 or not completion.choices[0].text:
             raise RuntimeError('No response from OpenAI')
-        response_text: str = completion.choices[0].text
+        response_text: Final[str] = completion.choices[0].text
         return make_executable_command(response_text)
+
+    def _construct_prompt(self, text: str) -> str:
+        return f'''User: You are now a translater from human language to {os.uname()[0]} shell command.
+        No explanation required, respond with only the raw shell command.
+        What should I type to shell for: {text}, in one line.
+
+        You: '''
diff --git a/aishell/query_clients/official_chatgpt_client.py b/aishell/query_clients/official_chatgpt_client.py
index 5632037..0c84a22 100644
--- a/aishell/query_clients/official_chatgpt_client.py
+++ b/aishell/query_clients/official_chatgpt_client.py
@@ -1,5 +1,5 @@
 import os
-from typing import Final, Optional
+from typing import Optional
 
 from revChatGPT.V3 import Chatbot
 
@@ -10,28 +10,27 @@ class OfficialChatGPTClient(QueryClient):
-    openai_api_key: str
 
     def __init__(
         self,
         openai_api_key: Optional[str] = None,
     ):
         super().__init__()
-        OPENAI_API_KEY: Final[Optional[str]] = os.environ.get('OPENAI_API_KEY', openai_api_key)
+        OPENAI_API_KEY: Optional[str] = os.environ.get('OPENAI_API_KEY', openai_api_key)
         if OPENAI_API_KEY is None:
             raise UnauthorizedAccessError('OPENAI_API_KEY should not be none')
-        self.openai_api_key = OPENAI_API_KEY
-
-    def _construct_prompt(self, text: str) -> str:
-        return f'''You are now a translater from human language to {os.uname()[0]} shell command.
-        No explanation required, respond with only the raw shell command.
-        What should I type to shell for: {text}, in one line.'''
+        self.OPENAI_API_KEY = OPENAI_API_KEY
 
     def query(self, prompt: str) -> str:
-        prompt = self._construct_prompt(prompt)
+        chatbot = Chatbot(api_key=self.OPENAI_API_KEY)
 
-        chatbot = Chatbot(api_key=self.openai_api_key)
+        prompt = self._construct_prompt(prompt)
         response_text = chatbot.ask(prompt)
+        executable_command = make_executable_command(response_text)
+        return executable_command
 
-        return make_executable_command(response_text)
+    def _construct_prompt(self, text: str) -> str:
+        return f'''You are now a translater from human language to {os.uname()[0]} shell command.
+        No explanation required, respond with only the raw shell command.
+        What should I type to shell for: {text}, in one line.'''
diff --git a/aishell/query_clients/reverse_engineered_chatgpt_client.py b/aishell/query_clients/reverse_engineered_chatgpt_client.py
index fb3277a..ea0acb3 100644
--- a/aishell/query_clients/reverse_engineered_chatgpt_client.py
+++ b/aishell/query_clients/reverse_engineered_chatgpt_client.py
@@ -1,16 +1,21 @@
 import os
-from typing import Optional, cast
+from typing import Optional, Union, cast
 
 from revChatGPT.V1 import Chatbot
 
 from aishell.exceptions import UnauthorizedAccessError
+from aishell.models import RevChatGPTChatbotConfigModel
 from aishell.utils import make_executable_command
 
 from .query_client import QueryClient
 
 
 class ReverseEngineeredChatGPTClient(QueryClient):
-    config: dict[str, str] = {}
+    _config: RevChatGPTChatbotConfigModel
+
+    @property
+    def revchatgpt_config(self) -> dict[str, Union[str, bool]]:
+        return self._config.dict(exclude_none=True)
 
     def __init__(
         self,
@@ -19,21 +24,17 @@ def __init__(
     ):
         CHATGPT_ACCESS_TOKEN = os.environ.get('CHATGPT_ACCESS_TOKEN', access_token)
         CHATGPT_SESSION_TOKEN = os.environ.get('CHATGPT_SESSION_TOKEN', session_token)
-        if CHATGPT_ACCESS_TOKEN is not None:
-            self.config['access_token'] = CHATGPT_ACCESS_TOKEN
-        elif CHATGPT_SESSION_TOKEN is not None:
-            self.config['session_token'] = CHATGPT_SESSION_TOKEN
+        if CHATGPT_ACCESS_TOKEN:
+            self._config = RevChatGPTChatbotConfigModel(access_token=CHATGPT_ACCESS_TOKEN)
+        elif CHATGPT_SESSION_TOKEN:
+            self._config = RevChatGPTChatbotConfigModel(session_token=CHATGPT_SESSION_TOKEN)
         else:
             raise UnauthorizedAccessError('No access token or session token provided.')
 
-    def _construct_prompt(self, text: str) -> str:
-        return f'''You are now a translater from human language to {os.uname()[0]} shell command.
-        No explanation required, respond with only the raw shell command.
-        What should I type to shell for: {text}, in one line.'''
-
     def query(self, prompt: str) -> str:
         prompt = self._construct_prompt(prompt)
-        chatbot = Chatbot(config=self.config)
+        chatbot = Chatbot(config=self.revchatgpt_config)  # pyright: ignore [reportGeneralTypeIssues]
+        # ignore for wrong type hint of revchatgpt
 
         response_text = ''
         for data in chatbot.ask(prompt):
@@ -42,3 +43,8 @@ def query(self, prompt: str) -> str:
         response_text = make_executable_command(cast(str, response_text))
         return response_text
+
+    def _construct_prompt(self, text: str) -> str:
+        return f'''You are now a translater from human language to {os.uname()[0]} shell command.
+        No explanation required, respond with only the raw shell command.
+        What should I type to shell for: {text}, in one line.'''
diff --git a/aishell/utils/str_enum.py b/aishell/utils/str_enum.py
index fa4d763..016a6e1 100644
--- a/aishell/utils/str_enum.py
+++ b/aishell/utils/str_enum.py
@@ -4,7 +4,12 @@ class StrEnum(str, Enum):
-    def _generate_next_value_(name: str, start: int, count: int, last_values: list[Any]):  # type: ignore
+    def _generate_next_value_(  # pyright: ignore [reportIncompatibleMethodOverride], for pyright's bug
+        name: str,
+        start: int,
+        count: int,
+        last_values: list[Any],
+    ):
         return name.lower()
 
     def __repr__(self):
diff --git a/poetry.lock b/poetry.lock
index d39e033..25bc556 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1055,25 +1055,6 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker
 perf = ["ipython"]
 testing = ["flake8 (<5)", "flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)"]
 
-[[package]]
-name = "importlib-resources"
-version = "5.10.2"
-description = "Read resources from Python packages"
-category = "main"
-optional = false
-python-versions = ">=3.7"
-files = [
-    {file = "importlib_resources-5.10.2-py3-none-any.whl", hash = "sha256:7d543798b0beca10b6a01ac7cafda9f822c54db9e8376a6bf57e0cbd74d486b6"},
-    {file = "importlib_resources-5.10.2.tar.gz", hash = "sha256:e4a96c8cc0339647ff9a5e0550d9f276fc5a01ffa276012b58ec108cfd7b8484"},
-]
-
-[package.dependencies]
-zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""}
-
-[package.extras]
-docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"]
-testing = ["flake8 (<5)", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"]
-
 [[package]]
 name = "iniconfig"
 version = "2.0.0"
@@ -1147,8 +1128,6 @@ files = [
 
 [package.dependencies]
 attrs = ">=17.4.0"
-importlib-resources = {version = ">=1.4.0", markers = "python_version < \"3.9\""}
-pkgutil-resolve-name = {version = ">=1.3.10", markers = "python_version < \"3.9\""}
 pyrsistent = ">=0.14.0,<0.17.0 || >0.17.0,<0.17.1 || >0.17.1,<0.17.2 || >0.17.2"
 
 [package.extras]
@@ -1169,7 +1148,6 @@ files = [
 
 [package.dependencies]
 importlib-metadata = {version = ">=4.11.4", markers = "python_version < \"3.12\""}
-importlib-resources = {version = "*", markers = "python_version < \"3.9\""}
 "jaraco.classes" = "*"
 jeepney = {version = ">=0.4.2", markers = "sys_platform == \"linux\""}
 pywin32-ctypes = {version = ">=0.2.0", markers = "sys_platform == \"win32\""}
@@ -1549,18 +1527,6 @@ files = [
 
 [package.extras]
 testing = ["pytest", "pytest-cov"]
-[[package]]
-name = "pkgutil-resolve-name"
-version = "1.3.10"
-description = "Resolve a name to an object."
-category = "main"
-optional = false
-python-versions = ">=3.6"
-files = [
-    {file = "pkgutil_resolve_name-1.3.10-py3-none-any.whl", hash = "sha256:ca27cc078d25c5ad71a9de0a7a330146c4e014c2462d9af19c6b828280649c5e"},
-    {file = "pkgutil_resolve_name-1.3.10.tar.gz", hash = "sha256:357d6c9e6a755653cfd78893817c0853af365dd51ec97f3d358a819373bbd174"},
-]
-
 [[package]]
 name = "platformdirs"
 version = "2.6.2"
@@ -1813,6 +1779,17 @@ files = [
 
 [package.extras]
 plugins = ["importlib-metadata"]
+[[package]]
+name = "pyperclip"
+version = "1.8.2"
+description = "A cross-platform clipboard module for Python. (Only handles plain text for now.)"
+category = "main"
+optional = false
+python-versions = "*"
+files = [
+    {file = "pyperclip-1.8.2.tar.gz", hash = "sha256:105254a8b04934f0bc84e9c24eb360a591aaf6535c9def5f29d92af107a9bf57"},
+]
+
 [[package]]
 name = "pyright"
 version = "1.1.294"
@@ -2232,7 +2209,6 @@ files = [
 [package.dependencies]
 commonmark = ">=0.9.0,<0.10.0"
 pygments = ">=2.6.0,<3.0.0"
-typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.9\""}
 
 [package.extras]
 jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"]
@@ -2894,5 +2870,5 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools"
 
 [metadata]
 lock-version = "2.0"
-python-versions = "^3.8"
-content-hash = "ce5e5cb95fdc926de298a3a404768c515aa11343d6e97404afa9cde2f96003d4"
+python-versions = "^3.9"
+content-hash = "36d170caea5668ca3ba7f0f23ed456c6e51c6e28b9d099ca43d7e63e9cea1c82"
diff --git a/pyproject.toml b/pyproject.toml
index 6969f6b..b669d30 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -47,6 +47,7 @@ openai = "^0.26.5"
 pydantic = "^1.10.4"
 pyright = "^1.1.294"
 yt-dlp = "^2023.2.17"
+pyperclip = "^1.8.2"
 
 [tool.poetry.scripts]
 aishell = "aishell:main"