Skip to content

fix the function usage in the parser #21

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: pypi/0.0.0-alpha
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backend/preprocessor/antlr_parser/parse_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def parse_pg_tree(src_sql: str) -> (str, int, int, str):
except SelfParseError as e:
return None, e.line, e.column, e.msg
except Exception as e:
logging.error(f"An error occurred: {e}", file=sys.stderr)
logging.error(f"An error occurred: {e}")
raise e


Expand All @@ -62,7 +62,7 @@ def parse_mysql_tree(src_sql: str):
except SelfParseError as e:
return None, e.line, e.column, e.msg
except Exception as e:
logging.error(f"An error occurred: {e}", file=sys.stderr)
logging.error(f"An error occurred: {e}")
raise e


Expand All @@ -79,7 +79,7 @@ def parse_oracle_tree(src_sql: str):
except SelfParseError as e:
return None, e.line, e.column, e.msg
except Exception as e:
logging.error(f"An error occurred: {e}", file=sys.stderr)
logging.error(f"An error occurred: {e}")
return None, -1, -1, ''


Expand Down
77 changes: 74 additions & 3 deletions backend/preprocessor/antlr_parser/pg_parser/PostgreSQLParser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Generated from /data/Coding/LLM4DB/antlr_gram/pg/PostgreSQLParser.g4 by ANTLR 4.13.1
# encoding: utf-8
from antlr4 import *
from preprocessor.antlr_parser.pg_parser.PostgreSQLLexer import PostgreSQLLexer
from preprocessor.antlr_parser.pg_parser.LexerDispatchingErrorListener import LexerDispatchingErrorListener
from preprocessor.antlr_parser.pg_parser.ParserDispatchingErrorListener import ParserDispatchingErrorListener
from io import StringIO
import sys
if sys.version_info[1] > 5:
Expand Down Expand Up @@ -6680,8 +6683,8 @@ class PostgreSQLParser ( Parser ):
EndDollarStringConstant=678
AfterEscapeStringConstantWithNewlineMode_Continued=679

def __init__(self, input:TokenStream, output:TextIO = sys.stdout):
super().__init__(input, output)
def __init__(self, input:TokenStream):
super().__init__(input)
self.checkVersion("4.13.1")
self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache)
self._predicates = None
Expand Down Expand Up @@ -35395,7 +35398,7 @@ def createfunc_opt_list(self):
self._errHandler.sync(self)
_alt = self._interp.adaptivePredict(self._input,286,self._ctx)

ParseRoutineBody(_localctx)
self.ParseRoutineBody(localctx)

except RecognitionException as re:
localctx.exception = re
Expand Down Expand Up @@ -76297,6 +76300,74 @@ def b_expr_sempred(self, localctx:B_exprContext, predIndex:int):

if predIndex == 7:
return self.precpred(self._ctx, 1)

# Re-parse the text of a routine body and attach the resulting parse tree to
# the option item that holds it.
#
# Steps, as visible in this block:
#   1. Scan `_localctx.createfunc_opt_item()` for a LANGUAGE item whose name is
#      reachable as nonreservedword_or_sconst -> nonreservedword -> identifier
#      -> Identifier, and take that token's text as `lang`.
#   2. If no such language was found, return without touching anything.
#   3. Take the first option item that has a `func_as()` child, extract its
#      first string constant, and build a fresh parser over that text.
#   4. Run the `plsqlroot` rule for "plpgsql" or the `root` rule for "sql" and
#      assign the result to `Definition` on the object returned by `func_as()`.
#
# The method never returns a value; its only output is the `Definition` attribute.
def ParseRoutineBody(self, _localctx):
lang = None
for coi in _localctx.createfunc_opt_item():
if coi.LANGUAGE() is not None:
# Each level of the chain is None-checked before descending further.
if coi.nonreservedword_or_sconst() is not None:
if coi.nonreservedword_or_sconst().nonreservedword() is not None:
if coi.nonreservedword_or_sconst().nonreservedword().identifier() is not None:
if coi.nonreservedword_or_sconst().nonreservedword().identifier().Identifier() is not None:
lang = coi.nonreservedword_or_sconst().nonreservedword().identifier().Identifier().getText()
# NOTE(review): indentation is lost in this view, so it cannot be told here
# whether this `break` belongs to the innermost `if` or to an outer one —
# confirm against the real file.
break

# No usable language name: leave the context unchanged.
if lang is None:
return

# Pick the first option item that carries a func_as() child.
func_as = None
for a in _localctx.createfunc_opt_item():
if a.func_as() is not None:
func_as = a
break

if func_as is not None:
# Only the first string constant, sconst(0), is used as the body text.
txt = self.get_routine_body_string(func_as.func_as().sconst(0))
postgreSQL_parser = self.get_postgresql_parser(txt)
# The comparison is an exact, case-sensitive string match; for any other
# language value the parser is built but no rule is run and no Definition is set.
if lang == "plpgsql":
func_as.func_as().Definition = postgreSQL_parser.plsqlroot()
elif lang == "sql":
func_as.func_as().Definition = postgreSQL_parser.root()

def trim_quotes(self, s: str) -> str:
    """Return *s* without its first and last character.

    Values that are falsy (``None``, ``""``) or only one character long are
    returned unchanged. The characters removed are not inspected, so the
    caller is responsible for passing text that really is delimited.
    """
    if not s or len(s) < 2:
        return s
    return s[1:-1]

def unquote(self, s: str) -> str:
    """Collapse every doubled single quote (``''``) in *s* into one ``'``.

    Pairs are consumed left to right without overlapping, so an odd run of
    quotes keeps its trailing quote (``'''`` becomes ``''``). All other
    characters are copied through unchanged.
    """
    # str.replace scans for leftmost, non-overlapping matches, which is the
    # same traversal the previous hand-written index loop performed.
    return s.replace("''", "'")

def get_routine_body_string(self, rule) -> str:
    """Return the text of the string constant held by *rule*.

    *rule* must expose ``anysconst()``; the returned context is probed in
    this order:

    1. ``StringConstant``: outer characters trimmed, then doubled quotes
       collapsed.
    2. ``UnicodeEscapeStringConstant``: outer characters trimmed only.
    3. ``EscapeStringConstant``: outer characters trimmed only.
    4. Otherwise the texts of all ``DollarText()`` tokens are concatenated
       (an empty string when there are none).

    NOTE(review): ``trim_quotes`` removes exactly one character per end; if
    the escape/unicode token text carries a prefix before the opening quote,
    confirm against the lexer that this trimming is what is wanted.
    """
    anysconst = rule.anysconst()

    string_constant = anysconst.StringConstant()
    if string_constant is not None:
        return self.unquote(self.trim_quotes(string_constant.getText()))

    unicode_escape_string_constant = anysconst.UnicodeEscapeStringConstant()
    if unicode_escape_string_constant is not None:
        return self.trim_quotes(unicode_escape_string_constant.getText())

    escape_string_constant = anysconst.EscapeStringConstant()
    if escape_string_constant is not None:
        return self.trim_quotes(escape_string_constant.getText())

    # Join once instead of growing a string with += inside a loop.
    return ''.join(piece.getText() for piece in anysconst.DollarText())

def get_postgresql_parser(self, txt):
    """Build a new ``PostgreSQLParser`` over the text *txt*.

    The text is wrapped in an ``InputStream``, tokenised by a
    ``PostgreSQLLexer`` and buffered in a ``CommonTokenStream``. No rule is
    invoked here; the caller chooses the entry rule on the returned parser.
    """
    tokens = CommonTokenStream(PostgreSQLLexer(InputStream(txt)))
    return PostgreSQLParser(tokens)




Expand Down
Original file line number Diff line number Diff line change
@@ -1,75 +1,12 @@
from antlr4 import *
from antlr4.CommonTokenStream import CommonTokenStream

from preprocessor.antlr_parser.pg_parser.LexerDispatchingErrorListener import LexerDispatchingErrorListener
from preprocessor.antlr_parser.pg_parser.ParserDispatchingErrorListener import ParserDispatchingErrorListener

from PostgreSQLLexer import PostgreSQLLexer
from PostgreSQLParser import PostgreSQLParser


# Parser base class carrying the helpers used to re-parse routine bodies.
# NOTE(review): this view has lost all indentation, so the nesting described in
# the comments below is the most natural reading, not something the text proves.
class PostgreSQLParserBase(Parser):

# Forwards the token stream to the ANTLR `Parser` constructor unchanged.
def __init__(self, input_stream: TokenStream):
super().__init__(input_stream)

# Finds the routine's language among the option items of `_localctx`, re-parses
# the text of the first `func_as()` body with a fresh parser, and assigns the
# resulting tree to `Definition` on the object returned by `func_as()`.
# Returns nothing.
def parse_routine_body(self, _localctx: PostgreSQLParser.Createfunc_opt_listContext):
lang = None
for coi in _localctx.createfunc_opt_item():
if coi.LANGUAGE() is not None:
# Each level of the chain is None-checked before descending further.
if coi.nonreservedword_or_sconst() is not None:
if coi.nonreservedword_or_sconst().nonreservedword() is not None:
if coi.nonreservedword_or_sconst().nonreservedword().identifier() is not None:
if coi.nonreservedword_or_sconst().nonreservedword().identifier().Identifier() is not None:
lang = coi.nonreservedword_or_sconst().nonreservedword().identifier().Identifier().getText()
# NOTE(review): which `if` this `break` belongs to cannot be determined
# from this view — confirm against the real file.
break

# No usable language name: leave the context unchanged.
if lang is None:
return

# Pick the first option item that carries a func_as() child.
func_as = None
for a in _localctx.createfunc_opt_item():
if a.func_as() is not None:
func_as = a
break

if func_as is not None:
# Only the first string constant, sconst(0), is used as the body text.
txt = self.get_routine_body_string(func_as.func_as().sconst(0))
# NOTE(review): get_postgresql_parser is not defined in the visible body of
# this class — presumably supplied by a subclass; confirm.
postgreSQL_parser = self.get_postgresql_parser(txt)
# Exact, case-sensitive match; any other language sets no Definition.
if lang == "plpgsql":
func_as.func_as().Definition = postgreSQL_parser.plsqlroot()
elif lang == "sql":
func_as.func_as().Definition = postgreSQL_parser.root()

# Returns `s` without its first and last character; falsy or one-character
# values are returned unchanged. The removed characters are not inspected.
def trim_quotes(self, s: str) -> str:
return s[1:-1] if s and len(s) > 1 else s

# Collapses each doubled single quote ('') into one, scanning left to right.
def unquote(self, s: str) -> str:
slength = len(s)
r = []
i = 0
while i < slength:
c = s[i]
r.append(c)
# A quote immediately followed by another quote: skip the second one.
if c == '\'' and i < slength - 1 and s[i + 1] == '\'':
i += 1
i += 1
return ''.join(r)

# Returns the text of the string constant under `rule.anysconst()`, probing
# StringConstant, UnicodeEscapeStringConstant and EscapeStringConstant in that
# order and falling back to the concatenated DollarText() tokens.
def get_routine_body_string(self, rule: PostgreSQLParser.SconstContext) -> str:
anysconst = rule.anysconst()
string_constant = anysconst.StringConstant()
if string_constant is not None:
# Only this form has doubled quotes collapsed after trimming.
return self.unquote(self.trim_quotes(string_constant.getText()))
unicode_escape_string_constant = anysconst.UnicodeEscapeStringConstant()
if unicode_escape_string_constant is not None:
return self.trim_quotes(unicode_escape_string_constant.getText())
escape_string_constant = anysconst.EscapeStringConstant()
if escape_string_constant is not None:
return self.trim_quotes(escape_string_constant.getText())
# None of the quoted forms matched: concatenate the dollar-text pieces.
result = ''
dollar_text = anysconst.DollarText()
for s in dollar_text:
result += s.getText()
return result


9 changes: 5 additions & 4 deletions backend/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from translator.translate_prompt import SYSTEM_PROMPT_NA, USER_PROMPT_NA, \
SYSTEM_PROMPT_SEG, USER_PROMPT_SEG, SYSTEM_PROMPT_RET, USER_PROMPT_RET, EXAMPLE_PROMPT, JUDGE_INFO_PROMPT
from utils.constants import DIALECT_MAP, FAILED_TEMPLATE, CHUNK_SIZE, TRANSLATION_ANSWER_PATTERN, \
JUDGE_ANSWER_PATTERN, DIALECT_LIST, DIALECT_LIST_RULE
JUDGE_ANSWER_PATTERN, DIALECT_LIST, DIALECT_LIST_RULE, DIALECT_ABBREVIATIONS
from utils.tools import process_err_msg, process_history_text
from vector_store.chroma_store import ChromaStore

Expand Down Expand Up @@ -996,10 +996,11 @@ def main():
tgt_db_config=tgt_db_config, vector_config=vector_config,
history_id=None, out_type="file", out_dir=args.out_dir,
retrieval_on=args.retrieval_on, top_k=args.top_k)

full_name_src = DIALECT_ABBREVIATIONS.get(args.src_dialect, args.src_dialect)
full_name_tgt = DIALECT_ABBREVIATIONS.get(args.tgt_dialect, args.tgt_dialect)
if not tgt_db_config or not vector_config:
if not args.llm_model_name and (args.src_dialect in DIALECT_LIST_RULE
or args.tgt_dialect not in DIALECT_LIST_RULE):
if not args.llm_model_name and (full_name_src in DIALECT_LIST_RULE
and full_name_tgt in DIALECT_LIST_RULE):
translated_sql, model_ans_list, \
used_pieces, lift_histories = translator.rule_rewrite()
else:
Expand Down
6 changes: 5 additions & 1 deletion backend/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@
DIALECT_LIST_RULE = ["athena", "bigquery", "clickhouse", "databricks", "doris", "drill", "druid",
"duckdb", "dune", "hive", "materialize", "mysql", "oracle", "postgres",
"presto", "prql", "redshift", "risingwave", "snowflake", "spark", "spark2",
"sqlite", "starrocks", "tableau", "teradata", "trino", "tsql", "postgresql"]
"sqlite", "starrocks", "tableau", "teradata", "trino", "tsql"]
# Short dialect key -> product name including a version string.
DIALECT_MAP = {
'pg': 'PostgreSQL 14.7',
'mysql': "MySQL 8.4",
'oracle': "Oracle 11g"
}

# Short dialect key -> the spelling used in DIALECT_LIST_RULE ('postgres' is an
# entry of that list). backend/translate.py looks names up with
# `.get(name, name)`, so keys missing from this dict pass through unchanged.
DIALECT_ABBREVIATIONS = {
'pg': 'postgres'
}

ORACLE_COMMAND_OPEN = False

TRANSLATION_RESULT_TEMP = r"""
Expand Down