Terarea  2
The automation project
Loading...
Searching...
No Matches
sql_sanitisation_functions.py
Go to the documentation of this file.
1"""
2 File in charge of cleaning and sanitising sql queries before they are submitted to the database.
3"""
4
5from typing import List, Dict, Any, Union
6
7from display_tty import Disp, TOML_CONF, SAVE_TO_FILE, FILE_NAME
8
9from . import sql_constants as SCONST
10from .sql_time_manipulation import SQLTimeManipulation
11
12
14 """_summary_
15 """
16
def __init__(self, success: int = 0, error: int = 84, debug: bool = False) -> None:
    """Set up the status codes, logger, keyword lists and time helper used by the sanitiser.

    Args:
        success (int, optional): Status code kept in ``self.success``. Defaults to 0.
        error (int, optional): Status code kept in ``self.error`` (``beautify_table``
            returns it when there is no table content). Defaults to 84.
        debug (bool, optional): Enable debug mode; forwarded to the logger and to
            ``SQLTimeManipulation``. Defaults to False.
    """
    # Plain settings first: the helpers built below read ``self.debug``.
    self.success: int = success
    self.error: int = error
    self.debug: bool = debug
    # --------------------------- logger section ---------------------------
    # The logger is named after the concrete class of this instance.
    logger_name = type(self).__name__
    self.disp: Disp = Disp(
        TOML_CONF,
        SAVE_TO_FILE,
        FILE_NAME,
        debug=self.debug,
        logger=logger_name
    )
    # ----------------- Database risky keyword sanitising -----------------
    # NOTE(review): the doubled attribute name looks accidental, but the
    # escape_risky_column_names* methods read it under exactly this name,
    # so it is kept unchanged.
    self.risky_keywordsrisky_keywords: List[str] = SCONST.RISKY_KEYWORDS
    self.keyword_logic_gates: List[str] = SCONST.KEYWORD_LOGIC_GATES
    # ---------------------- Time manipulation class ----------------------
    time_helper = SQLTimeManipulation(self.debug)
    self.sql_time_manipulation: SQLTimeManipulation = time_helper
42
def protect_sql_cell(self, cell: str) -> str:
    """Escape the characters of a cell that could break or corrupt a quoted SQL literal.

    Each single quote, double quote and backslash is replaced by a backslash
    followed by that character. NUL is replaced by a backslash followed by the
    letter ``0``, and carriage return by a backslash followed by the letter ``r``.
    The control characters therefore never reach the query text as raw bytes.
    These are the MySQL-style escape sequences, which decode back to the same
    characters. An info message is logged for every escaped character.

    Args:
        cell (str): The cell to be checked.

    Returns:
        str: A (hopefully) clean string.
    """
    # Dangerous character -> escape sequence that replaces it. NUL and CR
    # use their letter escapes instead of "backslash + raw control byte".
    escapes = {
        "'": "\\'",
        '"': '\\"',
        "\\": "\\\\",
        "\0": "\\0",
        "\r": "\\r",
    }
    pieces = []
    for char in cell:
        replacement = escapes.get(char)
        if replacement is None:
            pieces.append(char)
            continue
        self.disp.log_info(
            f"Escaped character '{char}' in '{cell}'.",
            "_protect_sql_cell"
        )
        pieces.append(replacement)
    # join avoids the quadratic cost of repeated string concatenation.
    return "".join(pieces)
64
def escape_risky_column_names(self, columns: Union[List[str], str]) -> Union[List[str], str]:
    """Wrap column names that are risky SQL keywords in backticks.

    An entry of the form ``key=value`` has only its key checked and escaped.
    Any other entry is checked as a whole. The comparison against
    ``self.risky_keywordsrisky_keywords`` is case-insensitive, but the escaped
    name keeps its original spelling.

    Note: when a list is passed, it is modified in place and that same list
    object is returned.

    Args:
        columns (Union[List[str], str]): One column entry or a list of them.

    Returns:
        Union[List[str], str]: A string when a string was given, otherwise the
        (updated) input list.
    """
    log_title = "_escape_risky_column_names"
    self.disp.log_debug("Escaping risky column names.", log_title)
    single_value = isinstance(columns, str)
    entries = [columns] if single_value else columns
    for position, entry in enumerate(entries):
        if "=" not in entry:
            # Bare column name: escape the whole entry when it is risky.
            if entry.lower() in self.risky_keywordsrisky_keywords:
                self.disp.log_warning(
                    f"Escaping risky column name '{entry}'.",
                    log_title
                )
                entries[position] = f"`{entry}`"
            continue
        # "name=value": only the part before the first '=' is a column name.
        column_name, assigned = entry.split("=", maxsplit=1)
        self.disp.log_debug(
            f"key = {column_name}, value = {assigned}", log_title
        )
        if column_name.lower() in self.risky_keywordsrisky_keywords:
            self.disp.log_warning(
                f"Escaping risky column name '{column_name}'.",
                log_title
            )
            entries[position] = f"`{column_name}`={assigned}"
    self.disp.log_debug("Escaped risky column names.", log_title)
    return entries[0] if single_value else columns
103
def _protect_value(self, value: Any) -> str:
    """Render a value as a single-quoted SQL string literal.

    Rules, applied in order:

    * ``None`` becomes the bare keyword ``NULL``.
    * Non-string values are converted with ``str()``.
    * An empty string becomes ``''``.
    * A value of at least two characters that starts and ends with a backtick
      is returned untouched.
    * Otherwise, one leading and/or one trailing single quote is stripped, the
      remaining single quotes are doubled, and the result is wrapped in single
      quotes.

    Args:
        value (Any): The value that needs to be protected.

    Returns:
        str: The value with protection applied.
    """
    title = "_protect_value"
    self.disp.log_debug(f"protecting value: {value}", title)
    if value is None:
        self.disp.log_debug("Value is none, thus returning NULL", title)
        return "NULL"

    if not isinstance(value, str):
        self.disp.log_debug("Value is not a string, converting", title)
        value = str(value)

    if len(value) == 0:
        self.disp.log_debug("Value is empty, returning ''", title)
        return "''"

    # A lone backtick both starts and ends the string; require two characters
    # so it gets quoted instead of being placed raw into the query.
    # NOTE(review): backtick-wrapped input is passed through verbatim, which
    # bypasses quoting entirely - confirm callers never forward untrusted data.
    if len(value) > 1 and value[0] == '`' and value[-1] == '`':
        self.disp.log_debug(
            "string has special backtics, skipping.", title
        )
        return value

    # NOTE(review): the leading and trailing quotes are stripped independently,
    # so a value that merely ends in an apostrophe loses it - confirm intent.
    if value[0] == "'":
        self.disp.log_debug(
            "Value already has a single quote at the start, removing", title
        )
        value = value[1:]
    # The guard matters for the input "'": after the strip above the value is
    # empty and value[-1] would raise IndexError.
    if value and value[-1] == "'":
        self.disp.log_debug(
            "Value already has a single quote at the end, removing", title
        )
        value = value[:-1]

    self.disp.log_debug(
        f"Value before quote escaping: {value}", title
    )
    # Standard SQL escape for a single quote inside a literal: double it.
    protected_value = value.replace("'", "''")
    self.disp.log_debug(
        f"Value after quote escaping: {protected_value}", title
    )

    protected_value = f"'{protected_value}'"
    self.disp.log_debug(
        f"Value after being converted to a string: {protected_value}.",
        title
    )
    return protected_value
159
def escape_risky_column_names_where_mode(self, columns: Union[List[str], str]) -> Union[List[str], str]:
    """
    Escape risky column names in WHERE-clause entries, leaving logic gates alone.

    For every entry:

    * ``key=value``: the value is quoted with ``self._protect_value``. The key is
      wrapped in backticks when it is a risky keyword that is not listed in
      ``self.keyword_logic_gates``.
    * A bare entry that is a risky keyword (and not a logic gate) is wrapped in
      backticks, i.e. escaped as an identifier, as ``escape_risky_column_names``
      does.
    * Anything else is left untouched.

    Keyword comparisons are case-insensitive. When a list is passed, it is
    modified in place and that same list object is returned.

    Args:
        columns (Union[str, List[str]]): Column entries to be processed.

    Returns:
        Union[List[str], str]: Processed entries with risky names escaped.
    """
    title = "_escape_risky_column_names_where_mode"
    self.disp.log_debug(
        "Escaping risky column names in where mode.", title
    )

    def needs_backticks(word: str) -> bool:
        # Risky keywords are escaped, except entries of keyword_logic_gates,
        # which must stay usable as operators in the clause.
        lowered = word.lower()
        return (
            lowered not in self.keyword_logic_gates
            and lowered in self.risky_keywordsrisky_keywords
        )

    data = [columns] if isinstance(columns, str) else columns

    for index, item in enumerate(data):
        if "=" in item:
            key, value = item.split("=", maxsplit=1)
            self.disp.log_debug(f"key = {key}, value = {value}", title)

            protected_value = self._protect_value(value)
            if needs_backticks(key):
                self.disp.log_warning(
                    f"Escaping risky column name '{key}'.", title
                )
                data[index] = f"`{key}`={protected_value}"
            else:
                data[index] = f"{key}={protected_value}"

        elif needs_backticks(item):
            self.disp.log_warning(
                f"Escaping risky column name '{item}'.",
                title
            )
            # Identifier escaping: single quotes (the previous behaviour via
            # _protect_value) would turn the column name into a string literal.
            data[index] = f"`{item}`"

    self.disp.log_debug("Escaped risky column names in where mode.", title)

    if isinstance(columns, str):
        return data[0]
    return data
207
def check_sql_cell(self, cell: Any) -> Any:
    """Turn a single cell into a double-quoted SQL value.

    * ``None`` becomes the bare keyword ``NULL``.
    * ``str``, ``int`` and ``float`` values are converted to text and escaped
      with ``self.protect_sql_cell``.
    * ``now`` / ``now()`` and ``current_date`` / ``current_date()``
      (case-insensitive) are replaced by the values produced by
      ``self.sql_time_manipulation``.
    * The resulting text is wrapped in double quotes.

    Any other type, including ``bool``, is logged as an error and returned
    unchanged.

    Args:
        cell (Any): The cell value to format.

    Returns:
        Any: The formatted SQL value, or the unchanged input for unsupported types.
    """
    title = "_check_sql_cell"
    if cell is None:
        # A missing value is SQL NULL, as in _protect_value. Returning None
        # would break the string building in process_sql_line.
        self.disp.log_debug("result = NULL", title)
        return "NULL"
    # int used to fall through to the error branch and come back as an int,
    # which made process_sql_line's string concatenation raise TypeError.
    # bool (a subclass of int) keeps its previous handling.
    if isinstance(cell, (str, int, float)) and not isinstance(cell, bool):
        cell = str(cell)
    if not isinstance(cell, str):
        msg = "The expected type of the input is a string, "
        msg += f"but got {type(cell)}"
        self.disp.log_error(msg, title)
        return cell
    cell = self.protect_sql_cell(cell)
    lowered = cell.lower()
    if lowered in ("now", "now()"):
        result = self.sql_time_manipulation.get_correct_now_value()
    elif lowered in ("current_date", "current_date()"):
        result = self.sql_time_manipulation.get_correct_current_date_value()
    else:
        result = str(cell)
    # Skip logging values containing ";base" (presumably base64 data URIs,
    # which would flood the log - TODO confirm).
    if ";base" not in result:
        self.disp.log_debug(f"result = {result}", title)
    return f"\"{str(result)}\""
236
def beautify_table(self, column_names: List[str], table_content: List[List[Any]]) -> Union[List[Dict[str, Any]], int]:
    """Convert raw table rows into a list of ``{column: value}`` dictionaries.

    The key for each column is ``column_names[i][0]``. NOTE(review): this
    presumably expects DB-API ``cursor.description``-style tuples rather than
    the plain strings the annotation suggests; with plain strings the key
    would be the first character of the name - TODO confirm against callers.

    Rows whose length differs from the number of columns trigger a warning.
    A shorter row only fills the columns it has. Extra cells of a longer row
    are ignored.

    Args:
        column_names (List[str]): Column descriptors, indexed with ``[0]``.
        table_content (List[List[Any]]): The rows to convert.

    Returns:
        Union[List[Dict[str, Any]], int]: The converted rows. If
        ``column_names`` is empty, an error is logged and ``table_content`` is
        returned unchanged. If ``table_content`` is empty, an error is logged
        and ``self.error`` is returned.
    """
    log_title = "_beautify_table"
    if len(column_names) == 0:
        self.disp.log_error(
            "There are not provided table column names.",
            log_title
        )
        return table_content
    if len(table_content) == 0:
        self.disp.log_error(
            "There is no table content.",
            log_title
        )
        return self.error
    expected_width = len(column_names)
    rows: List[Dict[str, Any]] = []
    for row in table_content:
        row_width = len(row)
        if row_width != expected_width:
            self.disp.log_warning(
                "Table content and column lengths do not correspond.",
                log_title
            )
        record: Dict[str, Any] = {}
        rows.append(record)
        for position, column in enumerate(column_names):
            if position == row_width:
                self.disp.log_warning(
                    "Skipping the rest of the tuple because it is shorter than the column names.",
                    log_title
                )
                break
            record[column[0]] = row[position]
    self.disp.log_debug(f"beautified_table = {rows}", log_title)
    return rows
282
def compile_update_line(self, line: List, column: List, column_length: int) -> str:
    """Build the assignment list of an SQL UPDATE statement.

    Produces ``"<column[0]> = <cell 0>, <column[1]> = <cell 1>, ..."`` for the
    first ``column_length`` entries. Each cell is formatted by
    ``self.check_sql_cell``; column names are inserted as-is.

    Args:
        line (List): The new cell values, aligned with ``column``.
        column (List): The column names to assign.
        column_length (int): Number of leading columns to include; 0 or less
            yields an empty string.

    Returns:
        str: The comma-separated assignments.

    Raises:
        IndexError: If ``line`` or ``column`` has fewer than ``column_length`` entries.
    """
    title = "compile_update_line"
    self.disp.log_debug("Compiling update line.", title)
    assignments = []
    for i in range(column_length):
        # Format the cell first, then pair it with its column name.
        cell_content = self.check_sql_cell(line[i])
        assignments.append(f"{column[i]} = {cell_content}")
    # join places the separators; no per-iteration "is this the last" check.
    final_line = ", ".join(assignments)
    self.disp.log_debug(f"line = {final_line}", title)
    return final_line
307
def process_sql_line(self, line: List[str], column: List[str], column_length: int = (-1)) -> str:
    """Convert a row of values into a parenthesised SQL value tuple for insertion.

    Args:
        line (List[str]): The row values; each is formatted by ``self.check_sql_cell``.
        column (List[str]): The target columns; only its length is used, as the
            default for ``column_length``.
        column_length (int, optional): Number of leading cells to keep. ``-1``
            (the default) means ``len(column)``.

    Returns:
        str: ``"(<cell 0>, <cell 1>, ...)"``. A warning is logged once when
        ``line`` has more cells than are kept.

    Raises:
        IndexError: If ``line`` has fewer than ``column_length`` entries.
    """
    title = "_process_sql_line"
    if column_length == -1:
        column_length = len(column)
    line_length = len(line)

    # Lines containing ";base" are not logged (presumably base64 payloads that
    # would flood the log - TODO confirm).
    if self.debug is True and ";base" not in str(line):
        msg = f"line = {line}"
        self.disp.log_debug(msg, title)
    cells = [self.check_sql_cell(line[i]) for i in range(column_length)]
    # The previous in-loop check (i == column_length) could never be true, so
    # this warning was never emitted; check the lengths once instead.
    if line_length > max(column_length, 0):
        msg = "The line is longer than the number of columns, truncating."
        self.disp.log_warning(msg, title)
    line_final = "(" + ", ".join(cells) + ")"
    if ";base" not in str(line_final):
        msg = f"line_final = '{line_final}'"
        msg += f", type(line_final) = '{type(line_final)}'"
        self.disp.log_debug(msg, title)
    return line_final
None __init__(self, int success=0, int error=84, bool debug=False)
str compile_update_line(self, List line, List column, int column_length)
Union[List[Dict[str, Any]], int] beautify_table(self, List[str] column_names, List[List[Any]] table_content)
str process_sql_line(self, List[str] line, List[str] column, int column_length=(-1))
Union[List[str], str] escape_risky_column_names_where_mode(self, Union[List[str], str] columns)
Union[List[str], str] escape_risky_column_names(self, Union[List[str], str] columns)