From 3e539bfdea02a268515cbbb60b593ba64c645de6 Mon Sep 17 00:00:00 2001 From: ksyasuda Date: Wed, 9 Aug 2023 20:05:29 -0700 Subject: [PATCH] update pydb --- pydb/.gitignore | 70 +-- pydb/LICENSE | 19 - pydb/pyproject.toml | 29 +- pydb/requirements.txt | 36 ++ pydb/src/pydb/__init__.py | 3 + pydb/src/pydb/db.py | 519 +++++++++++++++++++++ pydb/src/pydb/factory/__init__.py | 1 + pydb/src/pydb/factory/db_factory.py | 31 ++ pydb/src/pydb/managers/__init__.py | 1 + pydb/src/pydb/managers/database_manager.py | 21 + pydb/src/pydb/utils.py | 275 +++++++++++ 11 files changed, 917 insertions(+), 88 deletions(-) create mode 100644 pydb/requirements.txt create mode 100644 pydb/src/pydb/__init__.py create mode 100755 pydb/src/pydb/db.py create mode 100644 pydb/src/pydb/factory/__init__.py create mode 100644 pydb/src/pydb/factory/db_factory.py create mode 100644 pydb/src/pydb/managers/__init__.py create mode 100644 pydb/src/pydb/managers/database_manager.py create mode 100644 pydb/src/pydb/utils.py diff --git a/pydb/.gitignore b/pydb/.gitignore index ebc86a6..9bf8bd0 100644 --- a/pydb/.gitignore +++ b/pydb/.gitignore @@ -1,60 +1,10 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class -config.py - -# C extensions -*.so - -# Distribution / packaging -.Python -env/ -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -*.egg-info/ -.installed.cfg -*.egg - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*,cover -.hypothesis/ - -# Translations -*.mo -*.pot - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Vim. -*.swp +.env +.DS_Store +.idea +*.log +tmp/ +env/* +dist/* +__pycache__* +*.egg-info +.vscode/* diff --git a/pydb/LICENSE b/pydb/LICENSE index 96f1555..e69de29 100644 --- a/pydb/LICENSE +++ b/pydb/LICENSE @@ -1,19 +0,0 @@ -Copyright (c) 2018 The Python Packaging Authority - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/pydb/pyproject.toml b/pydb/pyproject.toml index 3b429bf..6955e70 100644 --- a/pydb/pyproject.toml +++ b/pydb/pyproject.toml @@ -4,20 +4,31 @@ build-backend = "hatchling.build" [project] name = "pydb" -version = "0.0.3" -authors = [ - { name="Kyle Yasuda", email="suda@sudacode.com" }, +dependencies = [ + "mysql-connector-python", + "cx_oracle", + "snowflake-connector-python", + "pandas", + "python-logger", + "email-sender-simple", + "sqlparse", + "python-dotenv" ] -description = "A python database wrapper" +version = "0.5.17" +authors = [ + { name="Kyle Yasuda", email="kyasuda@westlakefinancial.com" }, +] +description = "Database helper" readme = "README.md" requires-python = ">=3.7" classifiers = [ "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ] -dependencies = ["mysql-connector-python"] -# [project.urls] -# "Homepage" = "https://gitea.suda.codes/sudacode/pydb" -# "Bug Tracker" = "https://gitea.suda.codes/sudacode/pydb/issues" +# [tool.setuptools.packages.find] +# where = ["pydb"] + +[project.urls] +"Homepage" = "https://gitlab.westlakefinancial.com/data_engineering/python_package_registry/-/tree/main/pydb" +"Bug Tracker" = "https://gitlab.westlakefinancial.com/data_engineering/python_package_registry/-/issues" diff --git a/pydb/requirements.txt b/pydb/requirements.txt new file mode 100644 index 0000000..486be1e --- /dev/null +++ b/pydb/requirements.txt @@ -0,0 +1,36 @@ +asn1crypto==1.5.1 +certifi==2022.12.7 +cffi==1.15.1 +charset-normalizer==2.1.1 +colored-output==0.0.1 +cryptography==40.0.1 +cx-Oracle==8.3.0 +email-sender-simple==0.2.8 +exceptiongroup==1.1.1 +filelock==3.11.0 +idna==3.4 +iniconfig==2.0.0 +mysql-connector-python==8.0.32 +numpy==1.24.2 +oscrypto==1.3.0 +packaging==23.0 +pandas==2.0.0 +pluggy==1.0.0 +protobuf==3.20.3 +pycparser==2.21 +pycryptodomex==3.17 +PyJWT==2.6.0 +pyOpenSSL==23.1.1 +pytest==7.3.1 +python-dateutil==2.8.2 +python-dotenv==1.0.0 +python-logger==0.1.13 +pytz==2023.3 +requests==2.28.2 +six==1.16.0 +snowflake-connector-python==3.0.2 +sqlparse==0.4.4 +tomli==2.0.1 +typing_extensions==4.5.0 +tzdata==2023.3 +urllib3==1.26.15 diff --git a/pydb/src/pydb/__init__.py b/pydb/src/pydb/__init__.py new file mode 100644 index 0000000..a2c5b74 --- /dev/null +++ b/pydb/src/pydb/__init__.py @@ -0,0 +1,3 @@ +from pydb.factory import db_factory +from pydb.managers import DatabaseManager +from pydb.utils import pretty_print_sql, query_to_csv, results_to_csv diff --git a/pydb/src/pydb/db.py b/pydb/src/pydb/db.py new file mode 100755 index 0000000..67971bf --- /dev/null +++ b/pydb/src/pydb/db.py @@ -0,0 +1,519 @@ +"Python Database Helper" +# import mariadb +import json +from collections.abc import Iterable +from copy import deepcopy +from email.errors import MessageError +from smtplib import SMTPException +from typing import Dict, Tuple, Union + +import cx_Oracle +import mysql.connector +import snowflake.connector +from email_sender_simple import send_email as email +from pandas import DataFrame + +import pydb.utils as utils + +# DB_ROWS = Union[Iterable[], tuple[int, float, str]] +DB_ROWS = Union[Iterable, Tuple[int, float, str]] +DB_INFO = Union[Dict[str, int], Dict[str, str]] +DB_ARGS = Union[int, float, str, Tuple[int, float, str], Iterable] +QUERY_RESULTS = Union[DB_ROWS, None] + + +class DB_WRAPPER: + """Base Class for DB Connection Wrapper""" + + def __init__(self, info: DB_INFO, connector, exception=None): + """DB Connection Wrapper Base Class Constructor + ---------- + Parameters + ---------- + info : DB_INFO + Dictionary containing connection information for the corresponding + database connector. + Example: + ```python + db_info = { + "host": "hostname", + "user": "username", + "password": "password" + } + ``` + connector : database connector functor + Database connector function + exception : Exception, optional + Exception to pass to databsae connector, by default None + """ + try: + self._conn = connector(**info) + self._cur = self._conn.cursor() + self._Exception = exception if exception is not None else Exception + except Exception as exception: + print("Something went wrong connection to DB:", exception) + raise exception + + def get_connection(self): + """ + Returns the connection object for the DB + + Returns + ------- + connection + Connection object for the DB + """ + return self._conn + + def get_cursor(self): + """ + Returns a reference to the cursor object + + Returns + ------- + cursor + Cursor object for the DB + """ + return self._cur + + def get_exception(self): + """ + Returns the exception handler for the class + + Returns + ------- + Exception + Exception handler for the class + """ + return self._Exception + + def execute( + self, + stmt: str, + args=None, + dictionary=False, + dataframe=False, + dataframe_headers=None, + logger=None, + to_csv=None, + ) -> QUERY_RESULTS: + """ + Queries the db with . Returns the results. + Commits the transaction if successful. + ---------- + Parameters + ---------- + stmt : str + Query to execute + args : int | float | str | List[int, float, str], optional + Arguments to pass to query, by default None + dictionary : bool, optional + Convert query result to dictionary, by default False + dataframe : bool, optional + Convert query result to dataframe, by default False + dataframe_headers : list, optional + List of column headers for dataframe (can only be used when + dataframe=True), by default None + logger : Logger, optional + Logger object, by default None + to_csv : str, optional + Path to save query results as csv, by default None + + Returns + ------- + QUERY_RESULTS + Query result + """ + if dictionary and dataframe: + raise Exception("Cannot specify both dictionary and dataframe") + if dictionary and to_csv: + raise Exception("Cannot specify both dictionary and to_csv") + try: + if args is not None: + # remove any surrounding quotes from the args + if not isinstance(args, (list, tuple)): + args = [args] + utils.lg(logger, f"Args: {args}", "DEBUG") + utils.pretty_print_sql(stmt, logger) + self._cur.execute(stmt, args) + else: + utils.lg(logger, "No args specified", "DEBUG") + self._cur.execute(stmt) + self._conn.commit() + utils.lg(logger, f"Finished executing query: {stmt}") + if to_csv is not None and to_csv != "": + res: DataFrame = utils.convert_query_result( + self._cur, is_dataframe=True, logger=logger + ) + return utils.results_to_csv(res, to_csv, logger=logger) + if "SELECT" in stmt: + return utils.convert_query_result( + self._cur, dictionary, dataframe, dataframe_headers, logger + ) + return None + except self._Exception as exception: + utils.lg( + logger, + f"Something went wrong: {exception}", + "ERROR", + ) + raise exception + + def query( + self, + stmt: str, + args=None, + dictionary=False, + dataframe=False, + dataframe_headers=None, + logger=None, + to_csv=None, + ): + """ + Queries the db with . Returns the results + Does not commit the transaction upon completion + ---------- + Parameters + ---------- + stmt : str + Query to execute + args : int | float | str | List[int, float, str], optional + Arguments to pass to query, by default None + dictionary : bool, optional + Convert query result to dictionary, by default False + dataframe : bool, optional + Convert query result to dataframe, by default False + dataframe_headers : list, optional + List of column headers for dataframe (can only be used when + dataframe=True), by default None + logger : Logger, optional + Logger object, by default None + to_csv : str, optional + Path to save query results as csv, by default None + + Returns + ------- + QUERY_RESULTS (int | float | str | List[int, float, str]) + Query result + """ + if dictionary and dataframe: + raise Exception("Cannot specify both dictionary and dataframe") + try: + if args is not None: + # remove any surrounding quotes from the args + if not isinstance(args, (list, tuple)): + args = [args] + utils.lg(logger, f"Args: {args}", "DEBUG") + utils.pretty_print_sql(stmt, logger) + self._cur.execute(stmt, args) + else: + utils.lg(logger, "No args specified", "DEBUG") + self._cur.execute(stmt) + utils.lg(logger, f"Finished executing query: {stmt}") + if to_csv is not None and to_csv != "": + res: DataFrame = utils.convert_query_result( + self._cur, is_dataframe=True, logger=logger + ) + return utils.results_to_csv(res, to_csv, logger=logger) + if "SELECT" in stmt: + return utils.convert_query_result( + self._cur, dictionary, dataframe, dataframe_headers, logger + ) + return None + except self._Exception as exception: + utils.lg(logger, f"Something went wrong: {exception}", "ERROR") + raise exception + + def close(self): + """Close the db connection and cursor.""" + self._cur.close() + self._conn.close() + + def execute_procedure( + self, + schema_name: str, + procedure_name: str, + package_name=None, + args=None, + logger=None, + send_email=False, + email_info=None, + error_email_to=None, + ): + """ + Executes a procedure with the given arguments or with none if not + provided + ---------- + Parameters + ---------- + schema_name: str + schema name + procedure_name: str + procedure name + package_name: str, optional + package name, by default None + args: str | List[str], optional + arguments to pass to the procedure, by default None + logger: Logger, optional + logger object, by default None + send_email: bool, optional + whether or not to send an email, by default False + email_info: dict, optional + dictionary containing email information + email_info = { + smtp_info: { + host: smtp host, + port: smtp port, + }, + email_to: email to send to, + email_from: email to send from, + error_email_to: email to send error to, + subject: email subject, + message_body: email message body + attatchments: path to attachment(s) + } + """ + stmt = None + if package_name is None and args is None: + stmt = f"CALL {schema_name}.{procedure_name}()" + elif package_name is None and args is not None: + if isinstance(args, (list, tuple)): + stmt = f"CALL {schema_name}.{procedure_name}({','.join(args)})" + else: + stmt = f"CALL {schema_name}.{procedure_name}({args})" + elif package_name is not None and args is None: + stmt = f"CALL {schema_name}.{package_name}.{procedure_name}()" + elif package_name is not None and args is not None: + if isinstance(args, (list, tuple)): + if len(args) > 1: + stmt = f"CALL {schema_name}.{package_name}.{procedure_name}({','.join(utils.surround_with_quotes(args))})" + else: + stmt = f"CALL {schema_name}.{package_name}.{procedure_name}({utils.surround_with_quotes(args[0])})" + + else: + stmt = f"CALL {schema_name}.{package_name}.{procedure_name}({args})" + if stmt is None: + raise Exception("No procedure name or args provided") + try: + utils.lg(logger, f"Executing procedure: {stmt}") + self._cur.execute(stmt) + self._conn.commit() + except self._Exception as e: + utils.lg(logger, f"Something went wrong executing the procedure: {e}") + if email_info is not None: + temp_email_info = deepcopy(email_info) + temp_email_info[ + "message_body" + ] = f"Something went wrong executing the procedure: {e}" + temp_email_info[ + "subject" + ] = f"Error executing procedure: {procedure_name}" + if ( + "error_email_to" in temp_email_info + and temp_email_info["error_email_to"] is not None + ): + logger.info( + f"Sending error email to: {temp_email_info['error_email_to']}" + ) + temp_email_info["email_to"] = temp_email_info["error_email_to"] + try: + email( + **temp_email_info, + logger=logger, + ) + except (MessageError, SMTPException, TimeoutError) as e: + if logger is not None: + logger.error(f"Error sending email: {e}") + raise e + raise e + utils.lg(logger, f"{stmt} executed successfully") + + if send_email and email_info is None: + utils.lg(logger, "No email info provided") + elif send_email: + try: + if logger is not None: + logger.debug(f"Email info: {json.dumps(email_info, indent=4)}") + email(**email_info, logger=logger) + except (MessageError, SMTPException, TimeoutError) as e: + raise e + + +class MysqlDB(DB_WRAPPER): + """Mysql Specific Functions""" + + def __init__(self, info): + """MySQL Connection Wrapper""" + try: + DB_WRAPPER.__init__( + self, info, mysql.connector.connect, mysql.connector.Error + ) + self._conn = self.get_connection() + self._cur = self.get_cursor() + except mysql.connector.Error as exception: + raise exception + + def get_curdate(self): + """Returns CURDATE() from MySQL.""" + return self.query("SELECT CURDATE()") + + def get_timestamp(self): + """Returns CURRENT_TIMESTAMP from MySQL.""" + return self.query("SELECT CURRENT_TIMESTAMP()") + + def table_exists(self, schema: str, table: str): + stmt = f""" + SELECT COUNT(*) from information_schema.TABLES + WHERE TABLE_SCHEMA = '{schema}' and TABLE_NAME = '{table}' + """ + return self.query(stmt) != 0 + + def plsql( + self, + in_plsql: str, + in_name: None, + logger=None, + send_email=False, + email_info=None, + error_email_to=None, + ): + raise NotImplementedError("MySQL does not support PLSQL") + + +class SnowflakeWrapper(DB_WRAPPER): + """Snowflake Specific Functions""" + + def __init__(self, info: DB_INFO): + DB_WRAPPER.__init__( + self, + info, + snowflake.connector.connect, + snowflake.connector.errors.ProgrammingError, + ) + + def plsql( + self, + in_plsql: str, + in_name: None, + logger=None, + send_email=False, + email_info=None, + error_email_to=None, + ): + raise NotImplementedError("Snowflake does not support PLSQL") + + +class OracleWrapper(DB_WRAPPER): + """Oracle specific functions.""" + + def __init__(self, info: DB_INFO): + try: + DB_WRAPPER.__init__(self, info, cx_Oracle.connect, cx_Oracle.Error) + self._db_info = info + except cx_Oracle.Error as e: + raise e + + def get_incoming_me(self): + """Returns the ME_INCOMING table from Oracle.""" + return self.query( + "SELECT POST_DAY_END.PDE_COMMONS_PKG.GET_INCOMING_DAYBREAK_ME_NAME FROM DUAL" + ) + + def get_daybreak_me_target_name(self): + """Returns the ME_TARGET table from Oracle.""" + return self.query( + "SELECT POST_DAY_END.PDE_COMMONS_PKG.GET_DAYBREAK_ME_TARGET_NAME FROM DUAL" + ) + + def plsql( + self, + in_plsql: str, + in_name=None, + logger=None, + send_email=False, + email_info=None, + error_email_to=None, + ): + """ + Executes PL/SQL block and optionally sends email. + ---------- + Parameters + ---------- + in_plsql : str + PL/SQL block + in_name : str, optional + Name of PL/SQL block + logger : logging.Logger, optional + Logger object + send_email : bool, optional + Whether or not to send email + email_info : dict, optional + Dictionary containing email information + email_info = { + smtp_info: { + host: smtp host, + port: smtp port, + }, + email_to: email to send to, + email_from: email to send from, + error_email_to: email to send error to (defaults to email_to), + subject: email subject, + message_body: email message body + attatchments: path to attachment(s) + } + """ + try: + if not in_plsql: + raise ValueError("No PL/SQL") + utils.lg(logger, "Executing PL/SQL block") + utils.lg(logger, f"{in_plsql}", "DEBUG") + self._cur.execute(in_plsql) + self._conn.commit() + except self._Exception as exception: + if email_info is not None: + temp_email_info = deepcopy(email_info) + temp_email_info[ + "message_body" + ] = f"PL/SQL failed to execute. {exception}" + temp_email_info["subject"] = ( + f"{in_name} failed to execute" + if in_name is not None + else "PLSQL failed to execute" + ) + if ( + "error_email_to" in temp_email_info + and temp_email_info["error_email_to"] is not None + ): + temp_email_info["email_to"] = temp_email_info["error_email_to"] + utils.lg( + logger, f"Sending error email to {temp_email_info['email_to']}" + ) + try: + email( + **temp_email_info, + logger=logger, + ) + except (MessageError, SMTPException, TimeoutError) as email_exception: + utils.lg( + logger, + f"Email failed to send: {email_exception}", + "ERROR", + ) + raise email_exception + raise exception + if send_email and email_info is None: + utils.lg(logger, "No email info provided") + elif send_email: + try: + utils.lg(logger, f"Email info: {email_info}", "DEBUG") + if logger is not None: + utils.lg( + logger, + f"Email info: {json.dumps(email_info, indent=4)}", + "DEBUG", + ) + email(**email_info, logger=logger) + except (MessageError, SMTPException, TimeoutError) as email_exception: + utils.lg(logger, f"Email failed to send: {email_exception}", "ERROR") + raise email_exception + utils.lg(logger, "plsql executed successfully") diff --git a/pydb/src/pydb/factory/__init__.py b/pydb/src/pydb/factory/__init__.py new file mode 100644 index 0000000..cb597ae --- /dev/null +++ b/pydb/src/pydb/factory/__init__.py @@ -0,0 +1 @@ +from pydb.factory.db_factory import db_factory diff --git a/pydb/src/pydb/factory/db_factory.py b/pydb/src/pydb/factory/db_factory.py new file mode 100644 index 0000000..f3f3062 --- /dev/null +++ b/pydb/src/pydb/factory/db_factory.py @@ -0,0 +1,31 @@ +"""DB Factory""" +from typing import Dict, Union + +import pydb + + +def db_factory(db_info: Union[Dict[str, int], Dict[str, str]], db_type: str): + """ + Returns a database object based on the database type. + + Parameters + ---------- + db_info: dict + Dictionary containing database connection information + db_type: str + Database type (mysql, oracle, snowflake) + + Returns + ------- + Database object + """ + db_type = db_type.strip().lower() + if db_type == "mysql": + return pydb.db.MysqlDB(db_info) + if db_type == "snowflake": + return pydb.db.SnowflakeWrapper(db_info) + if db_type in ("oracle", "prepdb", "bengal", "livdb", "slivdb"): + return pydb.db.OracleWrapper(db_info) + print("ERROR", db_type, "not valid") + print("Valid types: [ mysql | oracle | snowflake ]") + raise ValueError("Invalid database type") diff --git a/pydb/src/pydb/managers/__init__.py b/pydb/src/pydb/managers/__init__.py new file mode 100644 index 0000000..5fd6957 --- /dev/null +++ b/pydb/src/pydb/managers/__init__.py @@ -0,0 +1 @@ +from pydb.managers.database_manager import DatabaseManager diff --git a/pydb/src/pydb/managers/database_manager.py b/pydb/src/pydb/managers/database_manager.py new file mode 100644 index 0000000..6782f93 --- /dev/null +++ b/pydb/src/pydb/managers/database_manager.py @@ -0,0 +1,21 @@ +"""Database contest manager""" +from pydb.factory.db_factory import db_factory + + +class DatabaseManager(object): + """Context Manager for Database Connection + + Args: + db_info (DB_INFO): Database connection info + db_type (str): Database type + + """ + + def __init__(self, db_info, db_type): + self.db = db_factory(db_info, db_type) + + def __enter__(self): + return self.db + + def __exit__(self, type, value, traceback): + self.db.close() diff --git a/pydb/src/pydb/utils.py b/pydb/src/pydb/utils.py new file mode 100644 index 0000000..5d521af --- /dev/null +++ b/pydb/src/pydb/utils.py @@ -0,0 +1,275 @@ +"""Utility functions for pydb.""" +from csv import QUOTE_MINIMAL, QUOTE_NONNUMERIC +from logging import Logger as DefaultLogger +from typing import List, Union + +from pandas import DataFrame +from python_logger import Logger +from sqlparse import format + +from pydb.db import QUERY_RESULTS +from pydb.managers.database_manager import DatabaseManager + + +def lg(logger: Union[Logger, DefaultLogger, None], msg: str, level="info"): + level = level.strip().lower() + if logger: + if level == "info": + logger.info(msg) + elif level == "debug": + logger.debug(msg) + elif level == "warning": + logger.warning(msg) + elif level == "error": + logger.error(msg) + elif level == "critical": + logger.critical(msg) + elif level > "info": + print(msg) + + +def pretty_print_sql(sql: str, logger: Union[Logger, DefaultLogger, None] = None): + """ + Pretty print sql query or PL/SQL block. + ---------- + Parameters + ---------- + sql : str + SQL query or PL/SQL block + logger : Union[Logger, DefaultLogger, None], optional + Logger object, by default None + + Returns + ------- + str + Pretty printed SQL query or PL/SQL block + """ + try: + pretty_query = format( + sql, + reindent=True, + reindent_aligned=False, + keyword_case="upper", + identifier_case="upper", + indent_width=4, + wrap_after=120, + truncate_strings=120, + use_space_around_operators=True, + # output_format="python", + ) + lg(logger, pretty_query, "debug") + return pretty_query + except Exception as e: + lg(logger, "Failed to pretty print SQL", "ERROR") + raise e + + +def results_to_csv( + query_results: QUERY_RESULTS, + out_path: str, + column_headers=None, + logger=None, + quoting=QUOTE_MINIMAL, + quotechar='"', + delimiter=",", + lineterminator="\n", +): + """ + Execute a query and save the result to a csv file. + ---------- + Parameters + ---------- + query_results : QUERY_RESULTS + Query results + out_path : str + Output file path + column_headers : list, optional + Explicitly set column headers, by default None + logger : Union[Logger, DefaultLogger, None], optional + Logger object, by default None + quoting : int, optional + CSV quoting option, by default QUOTE_MINIMAL + quotechar : str, optional + CSV quotechar option, by default '\"' + delimiter : str, optional + CSV delimiter option, by default "," + lineterminator : str, optional + CSV lineterminator option, by default "\\n" + """ + if not isinstance(query_results, DataFrame): + lg(logger, "Converting query results to dataframe", "debug") + try: + query_results = DataFrame(query_results) + lg(logger, "Converted query results to dataframe", "debug") + except Exception as e: + raise e + try: + lg(logger, f"Saving query results to {out_path}", "info") + query_results.to_csv( + out_path, + index=False, + header=column_headers, + quoting=quoting, + quotechar=quotechar, + sep=delimiter, + lineterminator=lineterminator, + ) + lg(logger, f"Wrote query results to {out_path}", "info") + return out_path + except Exception as e: + raise e + + +def query_to_csv( + db_info, + db_type, + column_headers, + out_path, + query, + args=None, + logger=None, + quoting=QUOTE_MINIMAL, + quotechar='"', + delimiter=",", + lineterminator="\n", +): + """ + Execute a query and save the result to a csv file. + ---------- + Parameters + ---------- + db_info : str + Dictionary of database connection information + Example: + ```python + db_info = { + "host": "hostname", + "user": "username", + "password": "password" + } + ``` + db_type : str + Database type (one of "mysql", "oracle", "snowflake") + column_headers : list + List of column headers + out_path : str + Output file path + query : str + SQL query + args : list, optional + List of arguments for query, by default None + logger : Union[Logger, DefaultLogger, None], optional + Logger object, by default None + quoting : int, optional + CSV quoting option, by default QUOTE_MINIMAL + quotechar : str, optional + CSV quotechar option, by default '\"' + delimiter : str, optional + CSV delimiter option, by default "," + lineterminator : str, optional + CSV lineterminator option, by default "\\n" + """ + with DatabaseManager(db_info, db_type) as db: + try: + df: DataFrame = db.execute( + query, + args=args, + dataframe=True, + dataframe_headers=column_headers, + logger=logger, + ) + except Exception as e: + lg(logger, f"Failed to execute query: {query}", "error") + raise e + try: + # if not isinstance(header, (list, tuple)): + # header = [header] + # logger.debug("DF:", df) + # df = df.set_axis(header, axis=1, copy=False) + # logger.debug("DF:", df) + df.to_csv( + out_path, + index=False, + header=column_headers, + quoting=quoting, + quotechar=quotechar, + sep=delimiter, + lineterminator=lineterminator, + ) + except Exception as e: + raise e + return out_path + + +def convert_query_result( + cur, is_dictionary=False, is_dataframe=False, dataframe_headers=None, logger=None +) -> QUERY_RESULTS: + """ + Return the element or List(rows) if multiple results + Convert from List[List[rows]] -> List[rows] or List[result] -> result + + Parameters + ---------- + cur : cursor + Cursor object from database connection + is_dictionary : bool, optional + Convert query result to dictionary, by default False + is_dataframe : bool, optional + Convert query result to dataframe, by default False + dataframe_headers : list, optional + List of column headers for dataframe, by default None + logger : Logger, optional + Logger object, by default None + + Returns + ------- + QUERY_RESULTS + Query result + """ + if is_dictionary and is_dataframe: + raise Exception("Cannot be both dictionary and dataframe") + try: + res = cur.fetchall() + except Exception as e: + lg(logger, "Failed to fetch results", "ERROR") + raise e + if is_dictionary and cur.description is not None: + return dict_factory(res, cur.description) + if is_dataframe: + if len(res) > 0 and dataframe_headers is not None: + out = DataFrame(res, columns=dataframe_headers) + elif len(res) > 0: + out = DataFrame(res) + else: + out = None + return out + if res is None or (isinstance(res, list) and len(res) == 0): + return None + if len(res) == 1: + if len(res[0]) == 1: + return res[0][0] + return res[0] + return [i[0] if len(i) == 1 else list(i) for i in res] + + +def surround_with_quotes(s) -> Union[List[str], str]: + """Surround each element in list with quotes.""" + if isinstance(s, (list, tuple)): + return [f"'{i}'" for i in s] + return f"'{s}'" + + +def dict_factory(rows, description): + """Converts query result to a list of lists or list of dictionaries.""" + if rows is None: + return None + try: + desc = description + if desc is None: + return None + tdict = [ + dict(zip([col[0].strip().upper() for col in desc], row)) for row in rows + ] + return tdict if len(tdict) > 0 else None + except Exception as e: + raise e