From 181bc928843c1f769641514e9f6dbc41a4a3a91e Mon Sep 17 00:00:00 2001 From: Guillem Borrell Date: Mon, 17 Jun 2024 22:18:18 +0200 Subject: [PATCH] Successfully implemented function calling for anyscale --- notebooks/tasks.ipynb | 137 ++++++++++++++++++++++++++++++++++++++++++ requirements.in | 3 +- 2 files changed, 139 insertions(+), 1 deletion(-) create mode 100644 notebooks/tasks.ipynb diff --git a/notebooks/tasks.ipynb b/notebooks/tasks.ipynb new file mode 100644 index 0000000..9c5b010 --- /dev/null +++ b/notebooks/tasks.ipynb @@ -0,0 +1,137 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "from hellocomputer.config import settings\n", + "from langchain_core.utils.function_calling import convert_to_openai_function\n", + "import openai\n", + "from operator import itemgetter" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_core.tools import tool\n", + "\n", + "\n", + "@tool\n", + "def add(a: int, b: int) -> int:\n", + " \"\"\"Adds a and b.\"\"\"\n", + " return a + b\n", + "\n", + "\n", + "@tool\n", + "def multiply(a: int, b: int) -> int:\n", + " \"\"\"Multiplies a and b.\"\"\"\n", + " return a * b\n", + "\n", + "\n", + "tools = [convert_to_openai_function(t) for t in [add, multiply]]\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "tools_fmt = [\n", + " {\"type\": \"function\",\n", + " \"function\": tools[0]},\n", + " {\"type\": \"function\",\n", + " \"function\": tools[1]}\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "messages = [\n", + " {\"role\": \"system\", \"content\": \"You are helpful assistant.\"},\n", + " {\"role\": \"user\", \"content\": \"What is 2 + 2?\"},\n", + "]\n", + "\n", + "client = openai.OpenAI(\n", + " base_url = \"https://api.endpoints.anyscale.com/v1\",\n", + " api_key = settings.anyscale_api_key\n", + ")\n", + "\n", + "response = client.chat.completions.create(\n", + " model=\"mistralai/Mixtral-8x7B-Instruct-v0.1\",\n", + " messages=messages,\n", + " tools=tools_fmt,\n", + " tool_choice=\"auto\", # auto is default, but we'll be explicit\n", + ")\n", + "\n", + "get_args = itemgetter(\"arguments\")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "call = response.choices[0].message.tool_calls[0].function" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "add.func(2,3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/requirements.in b/requirements.in index 9521c80..065f1ee 100644 --- a/requirements.in +++ b/requirements.in @@ -13,4 +13,5 @@ pyjwt[crypto] python-multipart authlib itsdangerous -sqlalchemy \ No newline at end of file +sqlalchemy +openai \ No newline at end of file