From c5e9060069aa3f07faa3dc49f9ce6bc223abbd7c Mon Sep 17 00:00:00 2001 From: Sashankh Chengavalli Kumar Date: Tue, 23 Apr 2024 19:00:05 -0400 Subject: [PATCH 1/3] Implement Reddit as third party datasource --- docs/source/reference/databases/reddit.rst | 48 +++++ evadb/third_party/databases/interface.py | 2 + .../third_party/databases/reddit/__init__.py | 15 ++ .../databases/reddit/reddit_handler.py | 168 ++++++++++++++++++ .../databases/reddit/table_column_info.py | 35 ++++ run_reddit_command.py | 0 script/formatting/spelling.txt | 1 + setup.py | 7 +- .../long/test_reddit_datasource.py | 60 +++++++ test/markers.py | 4 + 10 files changed, 339 insertions(+), 1 deletion(-) create mode 100644 docs/source/reference/databases/reddit.rst create mode 100644 evadb/third_party/databases/reddit/__init__.py create mode 100644 evadb/third_party/databases/reddit/reddit_handler.py create mode 100644 evadb/third_party/databases/reddit/table_column_info.py create mode 100644 run_reddit_command.py create mode 100644 test/integration_tests/long/test_reddit_datasource.py diff --git a/docs/source/reference/databases/reddit.rst b/docs/source/reference/databases/reddit.rst new file mode 100644 index 000000000..c25115e91 --- /dev/null +++ b/docs/source/reference/databases/reddit.rst @@ -0,0 +1,48 @@ +Reddit +========== + +The connection to Reddit is based on the `praw `_ library. + +Dependency +---------- + +* praw + + +Parameters +---------- + +Required: + +* ``subreddit`` is the name of the subreddit from which the data is fetched. +* ``clientId`` is the unique identifier issued to the client when creating credentials on Reddit. Refer to the [First Steps](https://github.com/reddit-archive/reddit/wiki/OAuth2-Quick-Start-Example#first-steps) guide for more details on how to get this and the next two parameters. +* ``clientSecret`` is the secret key obtained when credentials are created that is used for authentication and authorization. +* ``userAgent`` is a string of your choosing that explains your use of the the Reddit API. More details are available in the guide linked above. + +Optional: + + +Create Connection +----------------- + +.. code-block:: text + + CREATE DATABASE reddit_data WITH ENGINE = 'reddit', PARAMETERS = { + "subreddit": "AskReddit", + "client_id": "abcd", + "clientSecret": "abcd1234", + "userAgent": "Eva DB Staging Build" + }; + +Supported Tables +---------------- + +* ``submissions``: Lists top submissions in the given subreddit. Check `databases/reddit/table_column_info.py` for all the available columns in the table. + +.. code-block:: sql + + SELECT * FROM hackernews_data.search_results LIMIT 3; + +.. note:: + + Looking for another table from Hackernews? Please raise a `Feature Request `_. diff --git a/evadb/third_party/databases/interface.py b/evadb/third_party/databases/interface.py index cacb4110f..3743ebee7 100644 --- a/evadb/third_party/databases/interface.py +++ b/evadb/third_party/databases/interface.py @@ -52,6 +52,8 @@ def _get_database_handler(engine: str, **kwargs): return mod.HackernewsSearchHandler(engine, **kwargs) elif engine == "slack": return mod.SlackHandler(engine, **kwargs) + elif engine == "reddit": + return mod.RedditHandler(engine, **kwargs) else: raise NotImplementedError(f"Engine {engine} is not supported") diff --git a/evadb/third_party/databases/reddit/__init__.py b/evadb/third_party/databases/reddit/__init__.py new file mode 100644 index 000000000..4dc9f9e97 --- /dev/null +++ b/evadb/third_party/databases/reddit/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""third party/applications/reddit""" \ No newline at end of file diff --git a/evadb/third_party/databases/reddit/reddit_handler.py b/evadb/third_party/databases/reddit/reddit_handler.py new file mode 100644 index 000000000..8637d957f --- /dev/null +++ b/evadb/third_party/databases/reddit/reddit_handler.py @@ -0,0 +1,168 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +from praw import Reddit +from prawcore import ResponseException + +from .table_column_info import SUBMISSION_COLUMNS +from ..types import DBHandler, DBHandlerResponse, DBHandlerStatus + + +class RedditHandler(DBHandler): + def __init__(self, name: str, **kwargs): + super().__init__(name) + self.clientId = kwargs.get("client_id") + self.clientSecret = kwargs.get("clientSecret") + self.userAgent = kwargs.get("userAgent") + self.subreddit = kwargs.get("subreddit") + + def connect(self): + try: + self.client = Reddit( + client_id=self.clientId, + client_secret=self.clientSecret, + user_agent=self.userAgent + ) + return DBHandlerStatus(status=True) + except Exception as e: + return DBHandlerStatus(status=False, error=str(e)) + + @property + def supported_table(self): + def _submission_generator(): + for submission in self.client.subreddit(self.subreddit).hot(): #TODO: REMOVE LIMIT + yield { + property_name: getattr(submission, property_name) + for property_name, _ in SUBMISSION_COLUMNS + } + + mapping = { + "submissions": { + "columns": SUBMISSION_COLUMNS, + "generator": _submission_generator(), + }, + } + return mapping + + def disconnect(self): + """ + No action required to disconnect from Reddit datasource + TODO: Add support for destroying session token if used in other flows + """ + return + #raise NotImplementedError() + + def check_connection(self) -> DBHandlerStatus: + try: + self.client.user.me() + except ResponseException as e: + return DBHandlerStatus(status=False, error=f"Received ResponseException: {e.response}") + return DBHandlerStatus(status=True) + + def get_tables(self) -> DBHandlerResponse: + connection_status = self.check_connection() + if not connection_status.status: + return DBHandlerResponse(data=None, error=str(connection_status)) + + try: + tables_df = pd.DataFrame( + list(self.supported_table.keys()), columns=["table_name"] + ) + return DBHandlerResponse(data=tables_df) + except Exception as e: + return DBHandlerResponse(data=None, error=str(e)) + + def get_columns(self, table_name: str) -> DBHandlerResponse: + columns = self.supported_table[table_name]["columns"] + columns_df = pd.DataFrame(columns, columns=["name", "dtype"]) + return DBHandlerResponse(data=columns_df) + + def select(self, table_name: str) -> DBHandlerResponse: + """ + Returns a generator that yields the data from the given table. + Args: + table_name (str): name of the table whose data is to be retrieved. + Returns: + DBHandlerResponse + """ + if not self.client: + return DBHandlerResponse(data=None, error="Not connected to the database.") + try: + if table_name not in self.supported_table: + return DBHandlerResponse( + data=None, + error="{} is not supported or does not exist.".format(table_name), + ) + # TODO: Projection column trimming optimization opportunity + return DBHandlerResponse( + data=None, + data_generator=self.supported_table[table_name]["generator"], + ) + except Exception as e: + return DBHandlerResponse(data=None, error=str(e)) + + # def post_message(self, message) -> DBHandlerResponse: + # try: + # response = self.client.chat_postMessage(channel=self.channel, text=message) + # return DBHandlerResponse(data=response["message"]["text"]) + # except SlackApiError as e: + # assert e.response["ok"] is False + # assert e.response["error"] + # return DBHandlerResponse(data=None, error=e.response["error"]) + # + # def _convert_json_response_to_DataFrame(self, json_response): + # messages = json_response["messages"] + # columns = ["text", "ts", "user"] + # data_df = pd.DataFrame(columns=columns) + # for message in messages: + # if message["text"] and message["ts"] and message["user"]: + # data_df.loc[len(data_df.index)] = [ + # message["text"], + # message["ts"], + # message["user"], + # ] + # return data_df + # + # def get_messages(self) -> DBHandlerResponse: + # try: + # channels = self.client.conversations_list( + # types="public_channel,private_channel" + # )["channels"] + # channel_ids = {c["name"]: c["id"] for c in channels} + # response = self.client.conversations_history( + # channel=channel_ids[self.channel_name] + # ) + # data_df = self._convert_json_response_to_DataFrame(response) + # return data_df + # + # except SlackApiError as e: + # assert e.response["ok"] is False + # assert e.response["error"] + # return DBHandlerResponse(data=None, error=e.response["error"]) + # + # def del_message(self, timestamp) -> DBHandlerResponse: + # try: + # self.client.chat_delete(channel=self.channel, ts=timestamp) + # except SlackApiError as e: + # assert e.response["ok"] is False + # assert e.response["error"] + # return DBHandlerResponse(data=None, error=e.response["error"]) + + # def execute_native_query(self, query_string: str) -> DBHandlerResponse: + # """ + # TODO: integrate code for executing query on Reddit + # """ + # raise NotImplementedError() diff --git a/evadb/third_party/databases/reddit/table_column_info.py b/evadb/third_party/databases/reddit/table_column_info.py new file mode 100644 index 000000000..b06f12638 --- /dev/null +++ b/evadb/third_party/databases/reddit/table_column_info.py @@ -0,0 +1,35 @@ +from typing import Union + +# SUBMISSION_COLUMNS = [ +# ["author", str], +# ["author_flair_text", Union[str, None]], +# ["clicked", bool], +# ["created_utc", str], +# ["distinguished", bool], +# ["edited", bool], +# ["id", str], +# ["is_original_content", bool], +# ["is_self", bool], +# ["link_flair_template_id", "str"], +# ["link_flair_text", Union[str, None]], +# ["locked", bool], +# ["name", str], +# ["num_comments", int], +# ["over_18", bool], +# ["permalink", str], +# ["saved", bool], +# ["score", float], +# ["selftext", str], +# ["spoiler", bool], +# ["stickied", bool], +# ["title", str], +# ["upvote_ratio", float], +# ["url", str] +# ] + + +SUBMISSION_COLUMNS = [ + ["author", str], + ["score", float], + ["title", str], +] diff --git a/run_reddit_command.py b/run_reddit_command.py new file mode 100644 index 000000000..e69de29bb diff --git a/script/formatting/spelling.txt b/script/formatting/spelling.txt index 1dd5566ca..0444f935b 100644 --- a/script/formatting/spelling.txt +++ b/script/formatting/spelling.txt @@ -695,6 +695,7 @@ PlanOprType Popen PostgresHandler PostgresNativeStorageEngineTest +praw PredicateExecutor PredicatePlan PredictEmployee diff --git a/setup.py b/setup.py index e3d211ece..be7bdbdf9 100644 --- a/setup.py +++ b/setup.py @@ -138,6 +138,10 @@ def read(path, encoding="utf-8"): "replicate" ] +reddit_libs = [ + "praw" +] + ### NEEDED FOR DEVELOPER TESTING ONLY dev_libs = [ @@ -183,8 +187,9 @@ def read(path, encoding="utf-8"): "xgboost": xgboost_libs, "forecasting": forecasting_libs, "hackernews": hackernews_libs, + "reddit": reddit_libs, # everything except ray, qdrant, ludwig and postgres. The first three fail on pyhton 3.11. - "dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs + xgboost_libs + "dev": dev_libs + vision_libs + document_libs + function_libs + notebook_libs + forecasting_libs + sklearn_libs + imagegen_libs + xgboost_libs + reddit_libs } setup( diff --git a/test/integration_tests/long/test_reddit_datasource.py b/test/integration_tests/long/test_reddit_datasource.py new file mode 100644 index 000000000..b3a1c866f --- /dev/null +++ b/test/integration_tests/long/test_reddit_datasource.py @@ -0,0 +1,60 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +from test.markers import reddit_skip_marker +from test.util import get_evadb_for_testing + +import pytest + +from evadb.server.command_handler import execute_query_fetch_all +from evadb.third_party.databases.reddit.table_column_info import SUBMISSION_COLUMNS + + +@pytest.mark.notparallel +class RedditDataSourceTest(unittest.TestCase): + def setUp(self): + self.evadb = get_evadb_for_testing() + # reset the catalog manager before running each test + self.evadb.catalog().reset() + + def tearDown(self): + execute_query_fetch_all(self.evadb, "DROP DATABASE IF EXISTS reddit_data;") + + @reddit_skip_marker + def test_should_run_select_query_on_reddit(self): + # Create database. + params = { + "subreddit": "cricket", + "client_id": 'clientid..', + "client_secret": 'clientsecret..', + "user_agent": 'test script for dev eva' + } + query = f"""CREATE DATABASE reddit_data + WITH ENGINE = "reddit", + PARAMETERS = {params};""" + execute_query_fetch_all(self.evadb, query) + + query = "SELECT * FROM reddit_data.submissions LIMIT 10;" + batch = execute_query_fetch_all(self.evadb, query) + self.assertEqual(len(batch), 10) + expected_column = list( + ["submissions.{}".format(col) for col, _ in SUBMISSION_COLUMNS] + ) + self.assertEqual(batch.columns, expected_column) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/markers.py b/test/markers.py index deefadb29..ad3fa1169 100644 --- a/test/markers.py +++ b/test/markers.py @@ -117,3 +117,7 @@ stable_diffusion_skip_marker = pytest.mark.skipif( is_replicate_available() is False, reason="requires replicate" ) + +reddit_skip_marker = pytest.mark.skip( + reason="requires Reddit secret key" +) \ No newline at end of file From 3cb048ce003af978a5aee2f4a4c1170b0550746c Mon Sep 17 00:00:00 2001 From: Sashankh Chengavalli Kumar Date: Tue, 23 Apr 2024 19:02:32 -0400 Subject: [PATCH 2/3] Enable all columns and remove problematic field --- .../databases/reddit/table_column_info.py | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/evadb/third_party/databases/reddit/table_column_info.py b/evadb/third_party/databases/reddit/table_column_info.py index b06f12638..088406458 100644 --- a/evadb/third_party/databases/reddit/table_column_info.py +++ b/evadb/third_party/databases/reddit/table_column_info.py @@ -1,35 +1,27 @@ from typing import Union -# SUBMISSION_COLUMNS = [ -# ["author", str], -# ["author_flair_text", Union[str, None]], -# ["clicked", bool], -# ["created_utc", str], -# ["distinguished", bool], -# ["edited", bool], -# ["id", str], -# ["is_original_content", bool], -# ["is_self", bool], -# ["link_flair_template_id", "str"], -# ["link_flair_text", Union[str, None]], -# ["locked", bool], -# ["name", str], -# ["num_comments", int], -# ["over_18", bool], -# ["permalink", str], -# ["saved", bool], -# ["score", float], -# ["selftext", str], -# ["spoiler", bool], -# ["stickied", bool], -# ["title", str], -# ["upvote_ratio", float], -# ["url", str] -# ] - - SUBMISSION_COLUMNS = [ ["author", str], + ["author_flair_text", Union[str, None]], + ["clicked", bool], + ["created_utc", str], + ["distinguished", bool], + ["edited", bool], + ["id", str], + ["is_original_content", bool], + ["is_self", bool], + ["link_flair_text", Union[str, None]], + ["locked", bool], + ["name", str], + ["num_comments", int], + ["over_18", bool], + ["permalink", str], + ["saved", bool], ["score", float], + ["selftext", str], + ["spoiler", bool], + ["stickied", bool], ["title", str], -] + ["upvote_ratio", float], + ["url", str] +] \ No newline at end of file From a9450e76bd1f8b0dea716ce3fa7af978c17a7109 Mon Sep 17 00:00:00 2001 From: Sashankh Chengavalli Kumar Date: Tue, 23 Apr 2024 19:08:01 -0400 Subject: [PATCH 3/3] Reformat code to fix formatting errors and remove erroneous TODO in RedditDBHandler --- evadb/third_party/databases/reddit/__init__.py | 2 +- .../databases/reddit/reddit_handler.py | 12 +++++++----- .../databases/reddit/table_column_info.py | 18 ++++++++++++++++-- .../long/test_reddit_datasource.py | 7 +++---- test/markers.py | 4 +--- 5 files changed, 28 insertions(+), 15 deletions(-) diff --git a/evadb/third_party/databases/reddit/__init__.py b/evadb/third_party/databases/reddit/__init__.py index 4dc9f9e97..01c7e9e97 100644 --- a/evadb/third_party/databases/reddit/__init__.py +++ b/evadb/third_party/databases/reddit/__init__.py @@ -12,4 +12,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""third party/applications/reddit""" \ No newline at end of file +"""third party/applications/reddit""" diff --git a/evadb/third_party/databases/reddit/reddit_handler.py b/evadb/third_party/databases/reddit/reddit_handler.py index 8637d957f..3450e3c5b 100644 --- a/evadb/third_party/databases/reddit/reddit_handler.py +++ b/evadb/third_party/databases/reddit/reddit_handler.py @@ -17,8 +17,8 @@ from praw import Reddit from prawcore import ResponseException -from .table_column_info import SUBMISSION_COLUMNS from ..types import DBHandler, DBHandlerResponse, DBHandlerStatus +from .table_column_info import SUBMISSION_COLUMNS class RedditHandler(DBHandler): @@ -34,7 +34,7 @@ def connect(self): self.client = Reddit( client_id=self.clientId, client_secret=self.clientSecret, - user_agent=self.userAgent + user_agent=self.userAgent, ) return DBHandlerStatus(status=True) except Exception as e: @@ -43,7 +43,7 @@ def connect(self): @property def supported_table(self): def _submission_generator(): - for submission in self.client.subreddit(self.subreddit).hot(): #TODO: REMOVE LIMIT + for submission in self.client.subreddit(self.subreddit).hot(): yield { property_name: getattr(submission, property_name) for property_name, _ in SUBMISSION_COLUMNS @@ -63,13 +63,15 @@ def disconnect(self): TODO: Add support for destroying session token if used in other flows """ return - #raise NotImplementedError() + # raise NotImplementedError() def check_connection(self) -> DBHandlerStatus: try: self.client.user.me() except ResponseException as e: - return DBHandlerStatus(status=False, error=f"Received ResponseException: {e.response}") + return DBHandlerStatus( + status=False, error=f"Received ResponseException: {e.response}" + ) return DBHandlerStatus(status=True) def get_tables(self) -> DBHandlerResponse: diff --git a/evadb/third_party/databases/reddit/table_column_info.py b/evadb/third_party/databases/reddit/table_column_info.py index 088406458..6dd819b6e 100644 --- a/evadb/third_party/databases/reddit/table_column_info.py +++ b/evadb/third_party/databases/reddit/table_column_info.py @@ -1,3 +1,17 @@ +# coding=utf-8 +# Copyright 2018-2023 EvaDB +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Union SUBMISSION_COLUMNS = [ @@ -23,5 +37,5 @@ ["stickied", bool], ["title", str], ["upvote_ratio", float], - ["url", str] -] \ No newline at end of file + ["url", str], +] diff --git a/test/integration_tests/long/test_reddit_datasource.py b/test/integration_tests/long/test_reddit_datasource.py index b3a1c866f..4ba41086a 100644 --- a/test/integration_tests/long/test_reddit_datasource.py +++ b/test/integration_tests/long/test_reddit_datasource.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import unittest - from test.markers import reddit_skip_marker from test.util import get_evadb_for_testing @@ -38,9 +37,9 @@ def test_should_run_select_query_on_reddit(self): # Create database. params = { "subreddit": "cricket", - "client_id": 'clientid..', - "client_secret": 'clientsecret..', - "user_agent": 'test script for dev eva' + "client_id": "clientid..", + "client_secret": "clientsecret..", + "user_agent": "test script for dev eva", } query = f"""CREATE DATABASE reddit_data WITH ENGINE = "reddit", diff --git a/test/markers.py b/test/markers.py index ad3fa1169..070b885fe 100644 --- a/test/markers.py +++ b/test/markers.py @@ -118,6 +118,4 @@ is_replicate_available() is False, reason="requires replicate" ) -reddit_skip_marker = pytest.mark.skip( - reason="requires Reddit secret key" -) \ No newline at end of file +reddit_skip_marker = pytest.mark.skip(reason="requires Reddit secret key")