|
11 | 11 | import pandas as pd |
12 | 12 | import requests |
13 | 13 | import uvicorn |
14 | 14 | from fastapi import FastAPI, Request |
15 | 15 | from fastapi.exceptions import RequestValidationError |
16 | 16 | from fastapi.openapi.utils import get_openapi |
17 | 17 | from fastapi.responses import HTMLResponse, PlainTextResponse, RedirectResponse |
18 | | - |
19 | 18 | from .helpers import api_data_to_frame, response_to_frame |
| 19 | +from .handlers.sklearn import SKLearnHandler |
20 | 20 | from .meta import VetiverMeta |
21 | 21 | from .utils import _jupyter_nb, get_workbench_path |
22 | 22 | from .vetiver_model import VetiverModel |
| 23 | +from .types import SklearnPredictionTypes |
23 | 24 |
|
24 | 25 |
|
25 | 26 | class VetiverAPI: |
@@ -111,7 +112,6 @@ async def startup_event(): |
111 | 112 |
|
112 | 113 | @app.get("/", include_in_schema=False) |
113 | 114 | def docs_redirect(): |
114 | | - |
115 | 115 | redirect = "__docs__" |
116 | 116 |
|
117 | 117 | return RedirectResponse(redirect) |
@@ -200,65 +200,75 @@ async def validation_exception_handler(request, exc): |
200 | 200 |
|
201 | 201 | return app |
202 | 202 |
|
203 | | - def vetiver_post(self, endpoint_fx: Callable, endpoint_name: str = None, **kw): |
204 | | - """Create new POST endpoint that is aware of model input data |
| 203 | + def vetiver_post( |
| 204 | + self, |
| 205 | + endpoint_fx: Union[Callable, SklearnPredictionTypes], |
| 206 | + endpoint_name: str = None, |
| 207 | + **kw, |
| 208 | + ): |
| 209 | + """Define a new POST endpoint that utilizes the model's input data. |
205 | 210 |
|
206 | 211 | Parameters |
207 | 212 | ---------- |
208 | | - endpoint_fx : typing.Callable |
209 | | - Custom function to be run at endpoint |
| 213 | + endpoint_fx : Union[typing.Callable, Literal["predict", "predict_proba", "predict_log_proba"]] |
| 214 | + A callable function that specifies the custom logic to execute when the endpoint is called. |
| 215 | + This function should take input data (e.g., a DataFrame or dictionary) and return the desired output |
| 216 | + (e.g., predictions or transformed data). For scikit-learn models, endpoint_fx can also be one of |
| 217 | + "predict", "predict_proba", or "predict_log_proba" if the model supports these methods. |
| 218 | +
|
210 | 219 | endpoint_name : str |
211 | | - Name of endpoint |
| 220 | + The name of the endpoint to create. If not provided, defaults to the name of ``endpoint_fx``. |
212 | 221 |
|
213 | 222 | Examples |
214 | 223 | ------- |
215 | | - ```{python} |
| 224 | + ```python |
216 | 225 | from vetiver import mock, VetiverModel, VetiverAPI |
217 | 226 | X, y = mock.get_mock_data() |
218 | 227 | model = mock.get_mock_model().fit(X, y) |
219 | 228 |
|
220 | | - v = VetiverModel(model = model, model_name = "model", prototype_data = X) |
221 | | - v_api = VetiverAPI(model = v, check_prototype = True) |
| 229 | + v = VetiverModel(model=model, model_name="model", prototype_data=X) |
| 230 | + v_api = VetiverAPI(model=v, check_prototype=True) |
222 | 231 |
|
223 | 232 | def sum_values(x): |
224 | 233 | return x.sum() |
| 234 | +
|
225 | 235 | v_api.vetiver_post(sum_values, "sums") |
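| | +
| | + # For a scikit-learn model, a supported prediction-method name may be passed |
| | + # instead of a function; "preds" is an illustrative endpoint name. |
| | + v_api.vetiver_post("predict", "preds") |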
226 | 236 | ``` |
227 | 237 | """ |
228 | | - if not endpoint_name: |
229 | | - endpoint_name = endpoint_fx.__name__ |
230 | 238 |
|
231 | | - if endpoint_fx.__doc__ is not None: |
232 | | - api_desc = dedent(endpoint_fx.__doc__) |
233 | | - else: |
234 | | - api_desc = None |
235 | | - |
236 | | - if self.check_prototype is True: |
237 | | - |
238 | | - @self.app.post( |
239 | | - urljoin("/", endpoint_name), |
240 | | - name=endpoint_name, |
241 | | - description=api_desc, |
| 239 | + if isinstance(endpoint_fx, str): |
| 240 | + if not isinstance(self.model, SKLearnHandler): |
| 241 | + raise ValueError( |
| 242 | + "The 'endpoint_fx' parameter can only be a string when using scikit-learn models." |
| 243 | + ) |
| 244 | + self.vetiver_post( |
| 245 | + self.model.handler_predict, |
| 246 | + endpoint_name or endpoint_fx, |
| 247 | + check_prototype=self.check_prototype, |
| 248 | + prediction_type=endpoint_fx, |
242 | 249 | ) |
243 | | - async def custom_endpoint(input_data: List[self.model.prototype]): |
244 | | - _to_frame = api_data_to_frame(input_data) |
245 | | - predictions = endpoint_fx(_to_frame, **kw) |
246 | | - if isinstance(predictions, List): |
247 | | - return {endpoint_name: predictions} |
248 | | - else: |
249 | | - return predictions |
| 250 | + return |
250 | 251 |
|
251 | | - else: |
| 252 | + endpoint_name = endpoint_name or endpoint_fx.__name__ |
| 253 | + endpoint_doc = dedent(endpoint_fx.__doc__) if endpoint_fx.__doc__ else None |
252 | 254 |
|
253 | | - @self.app.post(urljoin("/", endpoint_name)) |
254 | | - async def custom_endpoint(input_data: Request): |
| 255 | + @self.app.post( |
| 256 | + urljoin("/", endpoint_name), |
| 257 | + name=endpoint_name, |
| 258 | + description=endpoint_doc, |
| 259 | + ) |
| 260 | + async def custom_endpoint(input_data: List[self.model.prototype] if self.check_prototype else Request): |
| 261 | + if self.check_prototype: |
| 262 | + served_data = api_data_to_frame(input_data) |
| 263 | + else: |
255 | 264 | served_data = await input_data.json() |
256 | | - predictions = endpoint_fx(served_data, **kw) |
257 | 265 |
|
258 | | - if isinstance(predictions, List): |
259 | | - return {endpoint_name: predictions} |
260 | | - else: |
261 | | - return predictions |
| 266 | + predictions = endpoint_fx(served_data, **kw) |
| 267 | + |
| 268 | + if isinstance(predictions, List): |
| 269 | + return {endpoint_name: predictions} |
| 270 | + else: |
| 271 | + return predictions |
262 | 272 |
|
263 | 273 | def run(self, port: int = 8000, host: str = "127.0.0.1", quiet_open=False, **kw): |
264 | 274 | """ |
|