From: Shane Jaroch Date: Mon, 11 Jul 2022 15:46:14 +0000 (-0400) Subject: Chore/python3.4 (#2) X-Git-Tag: v0.2.3~4 X-Git-Url: https://git.nutra.tk/v2?a=commitdiff_plain;h=7a598a67707ac502dbce6565a3564c1cccb564ff;p=nutratech%2Fcli.git Chore/python3.4 (#2) --- diff --git a/ntclient/__main__.py b/ntclient/__main__.py index e500e9c..8df1b2e 100644 --- a/ntclient/__main__.py +++ b/ntclient/__main__.py @@ -27,7 +27,6 @@ along with this program. If not, see . import argparse import sys import time -from typing import Sequence from urllib.error import HTTPError, URLError import argcomplete @@ -76,8 +75,12 @@ def build_argparser() -> argparse.ArgumentParser: return arg_parser -def main(args: Sequence[str] = None) -> int: - """Main method for CLI""" +def main(args: list = None) -> int: + """ + Main method for CLI + + @param args: List[str] | None + """ start_time = time.time() arg_parser = build_argparser() diff --git a/ntclient/persistence/sql/__init__.py b/ntclient/persistence/sql/__init__.py index 58d6f4e..807142d 100644 --- a/ntclient/persistence/sql/__init__.py +++ b/ntclient/persistence/sql/__init__.py @@ -1,23 +1,26 @@ """Main SQL persistence module, need to rethink circular imports and shared code""" import sqlite3 -from typing import Union # FIXME: maybe just use separate methods for calls with vs. without headers # avoid the mypy headaches, and the liberal comments # type: ignore -def sql_entries(sql_result: sqlite3.Cursor, headers=False) -> Union[list, tuple]: +def sql_entries(sql_result: sqlite3.Cursor) -> list: """Formats and returns a `sql_result()` for console digestion and output""" # TODO: return object: metadata, command, status, errors, etc? rows = sql_result.fetchall() - if headers: - headers = [x[0] for x in sql_result.description] - return headers, rows - return rows +def sql_entries_headers(sql_result: sqlite3.Cursor) -> tuple: + """Formats and returns a `sql_result()` for console digestion and output""" + rows = sql_result.fetchall() + headers = [x[0] for x in sql_result.description] + + return headers, rows + + def version(con: sqlite3.Connection) -> str: """Gets the latest entry from version table""" @@ -38,13 +41,11 @@ def close_con_and_cur( con.close() -def _sql( - con: sqlite3.Connection, - query: str, - db_name: str, - values: Union[tuple, list] = None, - headers=False, -) -> Union[list, tuple]: +def _prep_query( + con: sqlite3.Connection, query: str, db_name: str, values=None +) -> tuple: + """@param values: tuple | list | None""" + from ntclient import DEBUG # pylint: disable=import-outside-toplevel cur = con.cursor() @@ -52,19 +53,51 @@ def _sql( if DEBUG: print("%s.sqlite3: %s" % (db_name, query)) if values: - # TODO: better debug logging, more "control-findable", distinguish from most prints() + # TODO: better debug logging, more "control-findable", + # distinguish from most prints() print(values) # TODO: separate `entry` & `entries` entity for single vs. bulk insert? if values: if isinstance(values, list): - rows = cur.executemany(query, values) + result = cur.executemany(query, values) else: # tuple - rows = cur.execute(query, values) + result = cur.execute(query, values) else: - rows = cur.execute(query) + result = cur.execute(query) + + return cur, result + + +def _sql( + con: sqlite3.Connection, + query: str, + db_name: str, + values=None, +) -> list: + """@param values: tuple | list | None""" + + cur, result = _prep_query(con, query, db_name, values) + + # TODO: print " SELECTED", or other info + # BASED ON command SELECT/INSERT/DELETE/UPDATE + result = sql_entries(cur) + + close_con_and_cur(con, cur) + return result + + +def _sql_headers( + con: sqlite3.Connection, + query: str, + db_name: str, + values=None, +) -> tuple: + """@param values: tuple | list | None""" + + cur, result = _prep_query(con, query, db_name, values) + + result = sql_entries_headers(result) - # TODO: print " SELECTED", or other info BASED ON command SELECT/INSERT/DELETE/UPDATE - result = sql_entries(rows, headers=headers) close_con_and_cur(con, cur) return result diff --git a/ntclient/persistence/sql/nt/__init__.py b/ntclient/persistence/sql/nt/__init__.py index 8dee265..cc1e750 100644 --- a/ntclient/persistence/sql/nt/__init__.py +++ b/ntclient/persistence/sql/nt/__init__.py @@ -3,11 +3,11 @@ import os import sqlite3 from ntclient import NT_DB_NAME, NUTRA_DIR, __db_target_nt__ -from ntclient.persistence.sql import _sql, version +from ntclient.persistence.sql import _sql, _sql_headers, version from ntclient.utils.exceptions import SqlConnectError, SqlInvalidVersionError -def nt_sqlite_connect(version_check=True): +def nt_sqlite_connect(version_check=True) -> sqlite3.Connection: """Connects to the nt.sqlite3 file, or throws an exception""" db_path = os.path.join(NUTRA_DIR, NT_DB_NAME) if os.path.isfile(db_path): @@ -29,13 +29,19 @@ def nt_sqlite_connect(version_check=True): raise SqlConnectError("ERROR: nt database doesn't exist, please run `nutra init`") -def nt_ver(): +def nt_ver() -> str: """Gets version string for nt.sqlite3 database""" con = nt_sqlite_connect(version_check=False) return version(con) -def sql(query, values=None, headers=False): +def sql(query, values=None) -> list: """Executes a SQL command to nt.sqlite3""" con = nt_sqlite_connect() - return _sql(con, query, db_name="nt", values=values, headers=headers) + return _sql(con, query, db_name="nt", values=values) + + +def sql_headers(query, values=None) -> tuple: + """Executes a SQL command to nt.sqlite3""" + con = nt_sqlite_connect() + return _sql_headers(con, query, db_name="nt", values=values) diff --git a/ntclient/persistence/sql/nt/funcs.py b/ntclient/persistence/sql/nt/funcs.py index 7c042af..bee8fc6 100644 --- a/ntclient/persistence/sql/nt/funcs.py +++ b/ntclient/persistence/sql/nt/funcs.py @@ -1,5 +1,5 @@ """nt.sqlite3 functions module""" -from ntclient.persistence.sql.nt import sql +from ntclient.persistence.sql.nt import sql, sql_headers def sql_nt_next_index(table=None): @@ -33,7 +33,7 @@ FROM GROUP BY id; """ - return sql(query, headers=True) + return sql_headers(query) def sql_analyze_recipe(recipe_id): diff --git a/ntclient/persistence/sql/usda/__init__.py b/ntclient/persistence/sql/usda/__init__.py index a2f27aa..4c3df95 100644 --- a/ntclient/persistence/sql/usda/__init__.py +++ b/ntclient/persistence/sql/usda/__init__.py @@ -3,10 +3,9 @@ import os import sqlite3 import tarfile import urllib.request -from typing import Union from ntclient import NUTRA_DIR, USDA_DB_NAME, __db_target_usda__ -from ntclient.persistence.sql import _sql, version +from ntclient.persistence.sql import _sql, _sql_headers, version from ntclient.utils.exceptions import SqlConnectError, SqlInvalidVersionError @@ -92,10 +91,19 @@ def usda_ver() -> str: return version(con) -def sql(query, values=None, headers=False, version_check=True) -> Union[list, tuple]: +def sql(query, values=None, version_check=True) -> list: """Executes a SQL command to usda.sqlite3""" con = usda_sqlite_connect(version_check=version_check) # TODO: support argument: _sql(..., params=params, ...) - return _sql(con, query, db_name="usda", values=values, headers=headers) + return _sql(con, query, db_name="usda", values=values) + + +def sql_headers(query, values=None, version_check=True) -> tuple: + """Executes a SQL command to usda.sqlite3 [WITH HEADERS]""" + + con = usda_sqlite_connect(version_check=version_check) + + # TODO: support argument: _sql(..., params=params, ...) + return _sql_headers(con, query, db_name="usda", values=values) diff --git a/ntclient/persistence/sql/usda/funcs.py b/ntclient/persistence/sql/usda/funcs.py index ff21588..00bdd61 100644 --- a/ntclient/persistence/sql/usda/funcs.py +++ b/ntclient/persistence/sql/usda/funcs.py @@ -1,5 +1,5 @@ """usda.sqlite functions module""" -from ntclient.persistence.sql.usda import sql +from ntclient.persistence.sql.usda import sql, sql_headers from ntclient.utils import NUTR_ID_KCAL @@ -25,7 +25,7 @@ def sql_food_details(food_ids=None) -> list: food_ids = ",".join(str(x) for x in set(food_ids)) query = query % food_ids - return sql(query) # type: ignore + return sql(query) def sql_nutrients_overview() -> dict: @@ -40,7 +40,7 @@ def sql_nutrients_details() -> tuple: """Shows nutrients 'details'""" query = "SELECT * FROM nutrients_overview;" - return sql(query, headers=True) # type: ignore + return sql_headers(query) def sql_servings(food_ids) -> list: @@ -59,7 +59,7 @@ WHERE serv.food_id IN (%s); """ food_ids = ",".join(str(x) for x in set(food_ids)) - return sql(query % food_ids) # type: ignore + return sql(query % food_ids) def sql_analyze_foods(food_ids) -> list: @@ -77,7 +77,7 @@ WHERE """ # TODO: parameterized queries food_ids = ",".join(str(x) for x in set(food_ids)) - return sql(query % food_ids) # type: ignore + return sql(query % food_ids) ################################################################################ @@ -100,7 +100,7 @@ ORDER BY food_id; """ - return sql(query % (NUTR_ID_KCAL, nutrient_id)) # type: ignore + return sql(query % (NUTR_ID_KCAL, nutrient_id)) def sql_sort_foods(nutr_id) -> list: @@ -126,7 +126,7 @@ ORDER BY nut_data.nutr_val DESC; """ - return sql(query % nutr_id) # type: ignore + return sql(query % nutr_id) def sql_sort_foods_by_kcal(nutr_id) -> list: @@ -155,4 +155,4 @@ ORDER BY (nut_data.nutr_val / kcal.nutr_val) DESC; """ - return sql(query % nutr_id) # type: ignore + return sql(query % nutr_id)