Skip to content

Commit b9f7b1b

Browse files
authored
Merge pull request #68 from AnswerDotAI/fix/array-for-gemini
Add 'items' field to JSON schema for Gemini compatibility
2 parents e9073e2 + 060aae7 commit b9f7b1b

File tree

2 files changed

+63
-11
lines changed

2 files changed

+63
-11
lines changed

01_funccall.ipynb

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@
193193
" if t is empty: raise TypeError('Missing type')\n",
194194
" tmap = {int:\"integer\", float:\"number\", str:\"string\", bool:\"boolean\", list:\"array\", dict:\"object\"}\n",
195195
" tmap.update({k.__name__: v for k, v in tmap.items()})\n",
196-
" if getattr(t, '__origin__', None) in (list,tuple):\n",
196+
" if getattr(t, '__origin__', None) in (list,tuple,set):\n",
197197
" args = getattr(t, '__args__', None)\n",
198198
" item_type = \"object\" if not args else tmap.get(t.__args__[0].__name__, \"object\")\n",
199199
" return \"array\", item_type\n",
@@ -222,7 +222,10 @@
222222
{
223223
"data": {
224224
"text/plain": [
225-
"(('array', 'integer'), ('integer', None), ('integer', None))"
225+
"(('array', 'integer'),\n",
226+
" ('array', 'integer'),\n",
227+
" ('integer', None),\n",
228+
" ('integer', None))"
226229
]
227230
},
228231
"execution_count": null,
@@ -231,7 +234,7 @@
231234
}
232235
],
233236
"source": [
234-
"_types(list[int]), _types(int), _types('int')"
237+
"_types(list[int]), _types(set[int]), _types(int), _types('int')"
235238
]
236239
},
237240
{
@@ -435,9 +438,11 @@
435438
"\n",
436439
"def _handle_type(t, defs):\n",
437440
" \"Handle a single type, creating nested schemas if necessary\"\n",
441+
" ot = ifnone(get_origin(t), t)\n",
438442
" if t is NoneType: return {'type': 'null'}\n",
439443
" if t in custom_types: return {'type':'string', 'format':t.__name__}\n",
440-
" if t in (dict, list, tuple, set): return {'type': _types(t)[0]}\n",
444+
" if ot is dict: return {'type': _types(t)[0]} \n",
445+
" if ot in (list, tuple, set): return {'type': _types(t)[0], 'items':{}}\n",
441446
" if isinstance(t, type) and not issubclass(t, (int, float, str, bool)) or inspect.isfunction(t):\n",
442447
" defs[t.__name__] = _get_nested_schema(t)\n",
443448
" return {'$ref': f'#/$defs/{t.__name__}'}\n",
@@ -465,6 +470,51 @@
465470
"_handle_type(int, None), _handle_type(Path, None)"
466471
]
467472
},
473+
{
474+
"cell_type": "code",
475+
"execution_count": null,
476+
"id": "a43e9134",
477+
"metadata": {},
478+
"outputs": [
479+
{
480+
"data": {
481+
"text/plain": [
482+
"({'type': 'array', 'items': {}},\n",
483+
" {'type': 'array', 'items': {}},\n",
484+
" {'type': 'array', 'items': {}})"
485+
]
486+
},
487+
"execution_count": null,
488+
"metadata": {},
489+
"output_type": "execute_result"
490+
}
491+
],
492+
"source": [
493+
"# gemini expect `items` to be defined for arrays\n",
494+
"_handle_type(list, None), _handle_type(tuple[str], None), _handle_type(set[str], None)"
495+
]
496+
},
497+
{
498+
"cell_type": "code",
499+
"execution_count": null,
500+
"id": "cf6d417e",
501+
"metadata": {},
502+
"outputs": [
503+
{
504+
"data": {
505+
"text/plain": [
506+
"({'type': 'object'}, {'type': 'object'})"
507+
]
508+
},
509+
"execution_count": null,
510+
"metadata": {},
511+
"output_type": "execute_result"
512+
}
513+
],
514+
"source": [
515+
"_handle_type(dict, None), _handle_type(dict[str,str], None)"
516+
]
517+
},
468518
{
469519
"cell_type": "code",
470520
"execution_count": null,
@@ -1710,14 +1760,14 @@
17101760
"output_type": "stream",
17111761
"text": [
17121762
"Traceback (most recent call last):\n",
1713-
" File \"/var/folders/51/b2_szf2945n072c0vj2cyty40000gn/T/ipykernel_6265/2052945749.py\", line 14, in python\n",
1763+
" File \"/tmp/ipykernel_3890/2052945749.py\", line 14, in python\n",
17141764
" try: return _run(code, glb, loc)\n",
17151765
" ^^^^^^^^^^^^^^^^^^^^\n",
1716-
" File \"/var/folders/51/b2_szf2945n072c0vj2cyty40000gn/T/ipykernel_6265/1858893181.py\", line 18, in _run\n",
1766+
" File \"/tmp/ipykernel_3890/1858893181.py\", line 18, in _run\n",
17171767
" try: exec(compiled_code, glb, loc)\n",
17181768
" ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n",
17191769
" File \"<ast>\", line 1, in <module>\n",
1720-
" File \"/var/folders/51/b2_szf2945n072c0vj2cyty40000gn/T/ipykernel_6265/2052945749.py\", line 9, in handler\n",
1770+
" File \"/tmp/ipykernel_3890/2052945749.py\", line 9, in handler\n",
17211771
" def handler(*args): raise TimeoutError()\n",
17221772
" ^^^^^^^^^^^^^^^^^^^^\n",
17231773
"TimeoutError\n",
@@ -2051,7 +2101,7 @@
20512101
{
20522102
"cell_type": "code",
20532103
"execution_count": null,
2054-
"id": "7ac04e80-7bb9-4b52-8285-454684605d47",
2104+
"id": "73bca085",
20552105
"metadata": {},
20562106
"outputs": [],
20572107
"source": [

toolslm/funccall.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def _types(t:type)->tuple[str,Optional[str]]:
2525
if t is empty: raise TypeError('Missing type')
2626
tmap = {int:"integer", float:"number", str:"string", bool:"boolean", list:"array", dict:"object"}
2727
tmap.update({k.__name__: v for k, v in tmap.items()})
28-
if getattr(t, '__origin__', None) in (list,tuple):
28+
if getattr(t, '__origin__', None) in (list,tuple,set):
2929
args = getattr(t, '__args__', None)
3030
item_type = "object" if not args else tmap.get(t.__args__[0].__name__, "object")
3131
return "array", item_type
@@ -58,9 +58,11 @@ def _param(
5858

5959
def _handle_type(t, defs):
6060
"Handle a single type, creating nested schemas if necessary"
61+
ot = ifnone(get_origin(t), t)
6162
if t is NoneType: return {'type': 'null'}
6263
if t in custom_types: return {'type':'string', 'format':t.__name__}
63-
if t in (dict, list, tuple, set): return {'type': _types(t)[0]}
64+
if ot is dict: return {'type': _types(t)[0]}
65+
if ot in (list, tuple, set): return {'type': _types(t)[0], 'items':{}}
6466
if isinstance(t, type) and not issubclass(t, (int, float, str, bool)) or inspect.isfunction(t):
6567
defs[t.__name__] = _get_nested_schema(t)
6668
return {'$ref': f'#/$defs/{t.__name__}'}
@@ -235,7 +237,7 @@ def call_func(fc_name, fc_inputs, ns, raise_on_err=True):
235237
if raise_on_err: raise e from None
236238
else: return traceback.format_exc()
237239

238-
# %% ../01_funccall.ipynb #7ac04e80-7bb9-4b52-8285-454684605d47
240+
# %% ../01_funccall.ipynb #73bca085
239241
async def call_func_async(fc_name, fc_inputs, ns, raise_on_err=True):
240242
"Awaits the function `fc_name` with the given `fc_inputs` using namespace `ns`."
241243
res = call_func(fc_name, fc_inputs, ns, raise_on_err=raise_on_err)

0 commit comments

Comments
 (0)