From e7c942fabd2dcddd191f5d29c7f542e1cd797de5 Mon Sep 17 00:00:00 2001 From: BoyuXiao <2092327901@qq.com> Date: Tue, 10 Jun 2025 15:54:32 +0800 Subject: [PATCH 1/2] fix the function usage in the parser --- .../preprocessor/antlr_parser/parse_tree.py | 6 +- .../pg_parser/PostgreSQLParser.py | 77 ++++++++++++++++++- .../pg_parser/PostgreSQLParserBase.py | 65 +--------------- 3 files changed, 78 insertions(+), 70 deletions(-) diff --git a/backend/preprocessor/antlr_parser/parse_tree.py b/backend/preprocessor/antlr_parser/parse_tree.py index 64b8629..6ceb2a6 100644 --- a/backend/preprocessor/antlr_parser/parse_tree.py +++ b/backend/preprocessor/antlr_parser/parse_tree.py @@ -45,7 +45,7 @@ def parse_pg_tree(src_sql: str) -> (str, int, int, str): except SelfParseError as e: return None, e.line, e.column, e.msg except Exception as e: - logging.error(f"An error occurred: {e}", file=sys.stderr) + logging.error(f"An error occurred: {e}") raise e @@ -62,7 +62,7 @@ def parse_mysql_tree(src_sql: str): except SelfParseError as e: return None, e.line, e.column, e.msg except Exception as e: - logging.error(f"An error occurred: {e}", file=sys.stderr) + logging.error(f"An error occurred: {e}") raise e @@ -79,7 +79,7 @@ def parse_oracle_tree(src_sql: str): except SelfParseError as e: return None, e.line, e.column, e.msg except Exception as e: - logging.error(f"An error occurred: {e}", file=sys.stderr) + logging.error(f"An error occurred: {e}") return None, -1, -1, '' diff --git a/backend/preprocessor/antlr_parser/pg_parser/PostgreSQLParser.py b/backend/preprocessor/antlr_parser/pg_parser/PostgreSQLParser.py index 8b5d1dc..fe49044 100644 --- a/backend/preprocessor/antlr_parser/pg_parser/PostgreSQLParser.py +++ b/backend/preprocessor/antlr_parser/pg_parser/PostgreSQLParser.py @@ -1,6 +1,9 @@ # Generated from /data/Coding/LLM4DB/antlr_gram/pg/PostgreSQLParser.g4 by ANTLR 4.13.1 # encoding: utf-8 from antlr4 import * +from preprocessor.antlr_parser.pg_parser.PostgreSQLLexer import PostgreSQLLexer +from preprocessor.antlr_parser.pg_parser.LexerDispatchingErrorListener import LexerDispatchingErrorListener +from preprocessor.antlr_parser.pg_parser.ParserDispatchingErrorListener import ParserDispatchingErrorListener from io import StringIO import sys if sys.version_info[1] > 5: @@ -6680,8 +6683,8 @@ class PostgreSQLParser ( Parser ): EndDollarStringConstant=678 AfterEscapeStringConstantWithNewlineMode_Continued=679 - def __init__(self, input:TokenStream, output:TextIO = sys.stdout): - super().__init__(input, output) + def __init__(self, input:TokenStream): + super().__init__(input) self.checkVersion("4.13.1") self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) self._predicates = None @@ -35395,7 +35398,7 @@ def createfunc_opt_list(self): self._errHandler.sync(self) _alt = self._interp.adaptivePredict(self._input,286,self._ctx) - ParseRoutineBody(_localctx) + self.ParseRoutineBody(localctx) except RecognitionException as re: localctx.exception = re @@ -76297,6 +76300,74 @@ def b_expr_sempred(self, localctx:B_exprContext, predIndex:int): if predIndex == 7: return self.precpred(self._ctx, 1) + + def ParseRoutineBody(self, _localctx): + lang = None + for coi in _localctx.createfunc_opt_item(): + if coi.LANGUAGE() is not None: + if coi.nonreservedword_or_sconst() is not None: + if coi.nonreservedword_or_sconst().nonreservedword() is not None: + if coi.nonreservedword_or_sconst().nonreservedword().identifier() is not None: + if coi.nonreservedword_or_sconst().nonreservedword().identifier().Identifier() is not None: + lang = coi.nonreservedword_or_sconst().nonreservedword().identifier().Identifier().getText() + break + + if lang is None: + return + + func_as = None + for a in _localctx.createfunc_opt_item(): + if a.func_as() is not None: + func_as = a + break + + if func_as is not None: + txt = self.get_routine_body_string(func_as.func_as().sconst(0)) + postgreSQL_parser = self.get_postgresql_parser(txt) + if lang == "plpgsql": + func_as.func_as().Definition = postgreSQL_parser.plsqlroot() + elif lang == "sql": + func_as.func_as().Definition = postgreSQL_parser.root() + + def trim_quotes(self, s: str) -> str: + return s[1:-1] if s and len(s) > 1 else s + + def unquote(self, s: str) -> str: + slength = len(s) + r = [] + i = 0 + while i < slength: + c = s[i] + r.append(c) + if c == '\'' and i < slength - 1 and s[i + 1] == '\'': + i += 1 + i += 1 + return ''.join(r) + + def get_routine_body_string(self, rule) -> str: + anysconst = rule.anysconst() + string_constant = anysconst.StringConstant() + if string_constant is not None: + return self.unquote(self.trim_quotes(string_constant.getText())) + unicode_escape_string_constant = anysconst.UnicodeEscapeStringConstant() + if unicode_escape_string_constant is not None: + return self.trim_quotes(unicode_escape_string_constant.getText()) + escape_string_constant = anysconst.EscapeStringConstant() + if escape_string_constant is not None: + return self.trim_quotes(escape_string_constant.getText()) + result = '' + dollar_text = anysconst.DollarText() + for s in dollar_text: + result += s.getText() + return result + + def get_postgresql_parser(self, txt): + input_stream = InputStream(txt) + lexer = PostgreSQLLexer(input_stream) + token_stream = CommonTokenStream(lexer) + parser = PostgreSQLParser(token_stream) + return parser + diff --git a/backend/preprocessor/antlr_parser/pg_parser/PostgreSQLParserBase.py b/backend/preprocessor/antlr_parser/pg_parser/PostgreSQLParserBase.py index ec67c07..ace874c 100644 --- a/backend/preprocessor/antlr_parser/pg_parser/PostgreSQLParserBase.py +++ b/backend/preprocessor/antlr_parser/pg_parser/PostgreSQLParserBase.py @@ -1,75 +1,12 @@ from antlr4 import * from antlr4.CommonTokenStream import CommonTokenStream -from preprocessor.antlr_parser.pg_parser.LexerDispatchingErrorListener import LexerDispatchingErrorListener -from preprocessor.antlr_parser.pg_parser.ParserDispatchingErrorListener import ParserDispatchingErrorListener + from PostgreSQLLexer import PostgreSQLLexer from PostgreSQLParser import PostgreSQLParser - class PostgreSQLParserBase(Parser): def __init__(self, input_stream: TokenStream): super().__init__(input_stream) - - def parse_routine_body(self, _localctx: PostgreSQLParser.Createfunc_opt_listContext): - lang = None - for coi in _localctx.createfunc_opt_item(): - if coi.LANGUAGE() is not None: - if coi.nonreservedword_or_sconst() is not None: - if coi.nonreservedword_or_sconst().nonreservedword() is not None: - if coi.nonreservedword_or_sconst().nonreservedword().identifier() is not None: - if coi.nonreservedword_or_sconst().nonreservedword().identifier().Identifier() is not None: - lang = coi.nonreservedword_or_sconst().nonreservedword().identifier().Identifier().getText() - break - - if lang is None: - return - - func_as = None - for a in _localctx.createfunc_opt_item(): - if a.func_as() is not None: - func_as = a - break - - if func_as is not None: - txt = self.get_routine_body_string(func_as.func_as().sconst(0)) - postgreSQL_parser = self.get_postgresql_parser(txt) - if lang == "plpgsql": - func_as.func_as().Definition = postgreSQL_parser.plsqlroot() - elif lang == "sql": - func_as.func_as().Definition = postgreSQL_parser.root() - - def trim_quotes(self, s: str) -> str: - return s[1:-1] if s and len(s) > 1 else s - - def unquote(self, s: str) -> str: - slength = len(s) - r = [] - i = 0 - while i < slength: - c = s[i] - r.append(c) - if c == '\'' and i < slength - 1 and s[i + 1] == '\'': - i += 1 - i += 1 - return ''.join(r) - - def get_routine_body_string(self, rule: PostgreSQLParser.SconstContext) -> str: - anysconst = rule.anysconst() - string_constant = anysconst.StringConstant() - if string_constant is not None: - return self.unquote(self.trim_quotes(string_constant.getText())) - unicode_escape_string_constant = anysconst.UnicodeEscapeStringConstant() - if unicode_escape_string_constant is not None: - return self.trim_quotes(unicode_escape_string_constant.getText()) - escape_string_constant = anysconst.EscapeStringConstant() - if escape_string_constant is not None: - return self.trim_quotes(escape_string_constant.getText()) - result = '' - dollar_text = anysconst.DollarText() - for s in dollar_text: - result += s.getText() - return result - \ No newline at end of file From 50b8e093181c65467defaa67a5dfe6d3e53246a8 Mon Sep 17 00:00:00 2001 From: BoyuXiao <2092327901@qq.com> Date: Thu, 12 Jun 2025 22:22:40 +0800 Subject: [PATCH 2/2] fix the bug about rule_rewrite --- backend/translate.py | 9 +++++---- backend/utils/constants.py | 6 +++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/backend/translate.py b/backend/translate.py index 2199f55..7b4b856 100644 --- a/backend/translate.py +++ b/backend/translate.py @@ -26,7 +26,7 @@ from translator.translate_prompt import SYSTEM_PROMPT_NA, USER_PROMPT_NA, \ SYSTEM_PROMPT_SEG, USER_PROMPT_SEG, SYSTEM_PROMPT_RET, USER_PROMPT_RET, EXAMPLE_PROMPT, JUDGE_INFO_PROMPT from utils.constants import DIALECT_MAP, FAILED_TEMPLATE, CHUNK_SIZE, TRANSLATION_ANSWER_PATTERN, \ - JUDGE_ANSWER_PATTERN, DIALECT_LIST, DIALECT_LIST_RULE + JUDGE_ANSWER_PATTERN, DIALECT_LIST, DIALECT_LIST_RULE, DIALECT_ABBREVIATIONS from utils.tools import process_err_msg, process_history_text from vector_store.chroma_store import ChromaStore @@ -996,10 +996,11 @@ def main(): tgt_db_config=tgt_db_config, vector_config=vector_config, history_id=None, out_type="file", out_dir=args.out_dir, retrieval_on=args.retrieval_on, top_k=args.top_k) - + full_name_src = DIALECT_ABBREVIATIONS.get(args.src_dialect, args.src_dialect) + full_name_tgt = DIALECT_ABBREVIATIONS.get(args.tgt_dialect, args.tgt_dialect) if not tgt_db_config or not vector_config: - if not args.llm_model_name and (args.src_dialect in DIALECT_LIST_RULE - or args.tgt_dialect not in DIALECT_LIST_RULE): + if not args.llm_model_name and (full_name_src in DIALECT_LIST_RULE + and full_name_tgt in DIALECT_LIST_RULE): translated_sql, model_ans_list, \ used_pieces, lift_histories = translator.rule_rewrite() else: diff --git a/backend/utils/constants.py b/backend/utils/constants.py index 0a85364..bea976e 100644 --- a/backend/utils/constants.py +++ b/backend/utils/constants.py @@ -17,13 +17,17 @@ DIALECT_LIST_RULE = ["athena", "bigquery", "clickhouse", "databricks", "doris", "drill", "druid", "duckdb", "dune", "hive", "materialize", "mysql", "oracle", "postgres", "presto", "prql", "redshift", "risingwave", "snowflake", "spark", "spark2", - "sqlite", "starrocks", "tableau", "teradata", "trino", "tsql", "postgresql"] + "sqlite", "starrocks", "tableau", "teradata", "trino", "tsql"] DIALECT_MAP = { 'pg': 'PostgreSQL 14.7', 'mysql': "MySQL 8.4", 'oracle': "Oracle 11g" } +DIALECT_ABBREVIATIONS = { + 'pg': 'postgres' +} + ORACLE_COMMAND_OPEN = False TRANSLATION_RESULT_TEMP = r"""