diff --git a/api/index.py b/api/index.py index 810ea3c7..4bc38129 100644 --- a/api/index.py +++ b/api/index.py @@ -1,12 +1,13 @@ import os import json +import uuid from typing import List from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam from pydantic import BaseModel from dotenv import load_dotenv from fastapi import FastAPI, Query from fastapi.responses import StreamingResponse -from openai import OpenAI +from openai import AsyncOpenAI from .utils.prompt import ClientMessage, convert_to_openai_messages from .utils.tools import get_current_weather @@ -15,7 +16,7 @@ app = FastAPI() -client = OpenAI( +client = AsyncOpenAI( api_key=os.environ.get("OPENAI_API_KEY"), ) @@ -28,126 +29,144 @@ class Request(BaseModel): "get_current_weather": get_current_weather, } -def do_stream(messages: List[ChatCompletionMessageParam]): - stream = client.chat.completions.create( - messages=messages, - model="gpt-4o", - stream=True, - tools=[{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather at a location", - "parameters": { - "type": "object", - "properties": { - "latitude": { - "type": "number", - "description": "The latitude of the location", - }, - "longitude": { - "type": "number", - "description": "The longitude of the location", - }, - }, - "required": ["latitude", "longitude"], - }, - }, - }] - ) - - return stream - -def stream_text(messages: List[ChatCompletionMessageParam], protocol: str = 'data'): - draft_tool_calls = [] - draft_tool_calls_index = -1 - - stream = client.chat.completions.create( - messages=messages, - model="gpt-4o", - stream=True, - tools=[{ - "type": "function", - "function": { - "name": "get_current_weather", - "description": "Get the current weather at a location", - "parameters": { - "type": "object", - "properties": { - "latitude": { - "type": "number", - "description": "The latitude of the location", - }, - "longitude": { - "type": "number", - "description": "The longitude of the location", +async def stream_text(messages: List[ChatCompletionMessageParam], protocol: str = 'data'): + message_id = f"msg_{uuid.uuid4().hex}" + + yield f'data: {json.dumps({"type": "start", "messageId": message_id})}\n\n' + + conversation_messages = list(messages) + + while True: + text_id = f"text_{uuid.uuid4().hex}" + text_started = False + draft_tool_calls = [] + draft_tool_calls_index = -1 + + stream = await client.chat.completions.create( + messages=conversation_messages, + model="gpt-4o", + stream=True, + tools=[{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather at a location", + "parameters": { + "type": "object", + "properties": { + "latitude": { + "type": "number", + "description": "The latitude of the location", + }, + "longitude": { + "type": "number", + "description": "The longitude of the location", + }, }, + "required": ["latitude", "longitude"], }, - "required": ["latitude", "longitude"], }, - }, - }] - ) - - for chunk in stream: - for choice in chunk.choices: - if choice.finish_reason == "stop": - continue - - elif choice.finish_reason == "tool_calls": - for tool_call in draft_tool_calls: - yield '9:{{"toolCallId":"{id}","toolName":"{name}","args":{args}}}\n'.format( - id=tool_call["id"], - name=tool_call["name"], - args=tool_call["arguments"]) - - for tool_call in draft_tool_calls: - tool_result = available_tools[tool_call["name"]]( - **json.loads(tool_call["arguments"])) - - yield 'a:{{"toolCallId":"{id}","toolName":"{name}","args":{args},"result":{result}}}\n'.format( - id=tool_call["id"], - name=tool_call["name"], - args=tool_call["arguments"], - result=json.dumps(tool_result)) - - elif choice.delta.tool_calls: - for tool_call in choice.delta.tool_calls: - id = tool_call.id - name = tool_call.function.name - arguments = tool_call.function.arguments - - if (id is not None): - draft_tool_calls_index += 1 - draft_tool_calls.append( - {"id": id, "name": name, "arguments": ""}) - - else: - draft_tool_calls[draft_tool_calls_index]["arguments"] += arguments - - else: - yield '0:{text}\n'.format(text=json.dumps(choice.delta.content)) - - if chunk.choices == []: - usage = chunk.usage - prompt_tokens = usage.prompt_tokens - completion_tokens = usage.completion_tokens - - yield 'e:{{"finishReason":"{reason}","usage":{{"promptTokens":{prompt},"completionTokens":{completion}}},"isContinued":false}}\n'.format( - reason="tool-calls" if len( - draft_tool_calls) > 0 else "stop", - prompt=prompt_tokens, - completion=completion_tokens - ) + }] + ) + + finish_reason = None + + async for chunk in stream: + for choice in chunk.choices: + if choice.delta.tool_calls: + for tool_call in choice.delta.tool_calls: + id = tool_call.id + name = tool_call.function.name + arguments = tool_call.function.arguments + + if (id is not None): + draft_tool_calls_index += 1 + draft_tool_calls.append( + {"id": id, "name": name, "arguments": ""}) + + yield f'data: {json.dumps({"type": "tool-input-start", "toolCallId": id, "toolName": name})}\n\n' + + if arguments: + draft_tool_calls[draft_tool_calls_index]["arguments"] += arguments + yield f'data: {json.dumps({"type": "tool-input-delta", "toolCallId": draft_tool_calls[draft_tool_calls_index]["id"], "inputTextDelta": arguments})}\n\n' + + if choice.delta.content: + if not text_started: + yield f'data: {json.dumps({"type": "text-start", "id": text_id})}\n\n' + text_started = True + + yield f'data: {json.dumps({"type": "text-delta", "id": text_id, "delta": choice.delta.content})}\n\n' + + if choice.finish_reason: + finish_reason = choice.finish_reason + + if text_started: + yield f'data: {json.dumps({"type": "text-end", "id": text_id})}\n\n' + text_started = False + + if finish_reason == "tool_calls": + tool_calls_for_message = [ + { + "id": tc["id"], + "type": "function", + "function": { + "name": tc["name"], + "arguments": tc["arguments"] + } + } + for tc in draft_tool_calls + ] + + conversation_messages.append({ + "role": "assistant", + "tool_calls": tool_calls_for_message + }) + + for tool_call in draft_tool_calls: + parsed_args = json.loads(tool_call["arguments"]) + + yield f'data: {json.dumps({"type": "tool-input-available", "toolCallId": tool_call["id"], "toolName": tool_call["name"], "input": parsed_args})}\n\n' + + tool_result = available_tools[tool_call["name"]](**parsed_args) + + yield f'data: {json.dumps({"type": "tool-output-available", "toolCallId": tool_call["id"], "output": tool_result})}\n\n' + + conversation_messages.append({ + "role": "tool", + "tool_call_id": tool_call["id"], + "content": json.dumps(tool_result) + }) + + yield f'data: {json.dumps({"type": "finish-step"})}\n\n' + continue + + elif finish_reason == "stop": + break + + yield f'data: {json.dumps({"type": "finish"})}\n\n' + yield f'data: [DONE]\n\n' @app.post("/api/chat") async def handle_chat_data(request: Request, protocol: str = Query('data')): - messages = request.messages - openai_messages = convert_to_openai_messages(messages) - - response = StreamingResponse(stream_text(openai_messages, protocol)) - response.headers['x-vercel-ai-data-stream'] = 'v1' - return response + try: + messages = request.messages + openai_messages = convert_to_openai_messages(messages) + + return StreamingResponse( + stream_text(openai_messages, protocol), + media_type="text/event-stream", + headers={ + 'x-vercel-ai-ui-message-stream': 'v1', + 'Cache-Control': 'no-cache, no-transform', + 'X-Accel-Buffering': 'no', + 'Connection': 'keep-alive', + 'Content-Type': 'text/event-stream', + } + ) + except Exception as e: + import traceback + traceback.print_exc() + raise diff --git a/api/utils/prompt.py b/api/utils/prompt.py index f6b905e7..f36cbebb 100644 --- a/api/utils/prompt.py +++ b/api/utils/prompt.py @@ -1,85 +1,76 @@ import json -from enum import Enum from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam from pydantic import BaseModel -import base64 -from typing import List, Optional, Any -from .attachment import ClientAttachment - -class ToolInvocationState(str, Enum): - CALL = 'call' - PARTIAL_CALL = 'partial-call' - RESULT = 'result' - -class ToolInvocation(BaseModel): - state: ToolInvocationState - toolCallId: str - toolName: str - args: Any - result: Any - +from typing import List, Any +# AI SDK v5 UIMessage format class ClientMessage(BaseModel): + id: str role: str - content: str - experimental_attachments: Optional[List[ClientAttachment]] = None - toolInvocations: Optional[List[ToolInvocation]] = None + parts: List[Any] # v5 parts array (required) def convert_to_openai_messages(messages: List[ClientMessage]) -> List[ChatCompletionMessageParam]: + """ + Convert AI SDK v5 UIMessages to OpenAI message format. + """ openai_messages = [] for message in messages: - parts = [] + content_parts = [] tool_calls = [] - parts.append({ - 'type': 'text', - 'text': message.content - }) - - if (message.experimental_attachments): - for attachment in message.experimental_attachments: - if (attachment.contentType.startswith('image')): - parts.append({ - 'type': 'image_url', - 'image_url': { - 'url': attachment.url - } - }) - - elif (attachment.contentType.startswith('text')): - parts.append({ + # Process v5 parts array + for part in message.parts: + if isinstance(part, dict): + part_type = part.get('type') + + # Text parts + if part_type == 'text': + content_parts.append({ 'type': 'text', - 'text': attachment.url + 'text': part.get('text', '') + }) + + # File/Image parts + elif part_type == 'file': + media_type = part.get('mediaType', '') + if media_type.startswith('image'): + content_parts.append({ + 'type': 'image_url', + 'image_url': { + 'url': part.get('url') + } + }) + + # Tool call parts (for assistant messages) + elif part_type == 'tool-call': + tool_calls.append({ + "id": part.get('toolCallId'), + "type": "function", + "function": { + "name": part.get('toolName'), + "arguments": json.dumps(part.get('input', {})) + } }) - if(message.toolInvocations): - for toolInvocation in message.toolInvocations: - tool_calls.append({ - "id": toolInvocation.toolCallId, - "type": "function", - "function": { - "name": toolInvocation.toolName, - "arguments": json.dumps(toolInvocation.args) - } - }) - - tool_calls_dict = {"tool_calls": tool_calls} if tool_calls else {"tool_calls": None} - - openai_messages.append({ - "role": message.role, - "content": parts, - **tool_calls_dict, - }) + # Add message with content and/or tool calls + if content_parts or tool_calls: + msg = { + "role": message.role, + "content": content_parts if content_parts else None, + } + if tool_calls: + msg["tool_calls"] = tool_calls + openai_messages.append(msg) - if(message.toolInvocations): - for toolInvocation in message.toolInvocations: + # Add tool result messages (separate messages for tool outputs) + for part in message.parts: + if isinstance(part, dict) and part.get('type') == 'tool-call' and part.get('state') == 'result': tool_message = { "role": "tool", - "tool_call_id": toolInvocation.toolCallId, - "content": json.dumps(toolInvocation.result), + "tool_call_id": part.get('toolCallId'), + "content": json.dumps(part.get('result')), } - openai_messages.append(tool_message) return openai_messages diff --git a/components/chat.tsx b/components/chat.tsx index 82bcb6dd..657f700f 100644 --- a/components/chat.tsx +++ b/components/chat.tsx @@ -4,28 +4,19 @@ import { PreviewMessage, ThinkingMessage } from "@/components/message"; import { MultimodalInput } from "@/components/multimodal-input"; import { Overview } from "@/components/overview"; import { useScrollToBottom } from "@/hooks/use-scroll-to-bottom"; -import { ToolInvocation } from "ai"; -import { useChat } from "ai/react"; +import { useChat, type UIMessage } from "@ai-sdk/react"; import { toast } from "sonner"; +import React from "react"; export function Chat() { const chatId = "001"; - const { - messages, - setMessages, - handleSubmit, - input, - setInput, - append, - isLoading, - stop, - } = useChat({ - maxSteps: 4, - onError: (error) => { + const { messages, setMessages, sendMessage, status, stop } = useChat({ + id: chatId, + onError: (error: Error) => { if (error.message.includes("Too many requests")) { toast.error( - "You are sending too many messages. Please try again later.", + "You are sending too many messages. Please try again later." ); } }, @@ -34,6 +25,26 @@ export function Chat() { const [messagesContainerRef, messagesEndRef] = useScrollToBottom(); + // V5: Manage input state manually + const [input, setInput] = React.useState(""); + + const isLoading = status === "submitted" || status === "streaming"; + + const handleSubmit = (event?: { preventDefault?: () => void }) => { + event?.preventDefault?.(); + if (input.trim()) { + sendMessage({ text: input }); + setInput(""); + } + }; + + const append = async (message: any): Promise => { + if (message.content) { + sendMessage({ text: message.content }); + } + return undefined; + }; + return (
{messages.length === 0 && } - {messages.map((message, index) => ( + {messages.map((message: UIMessage, index: number) => ( { return ( @@ -25,7 +25,7 @@ export const PreviewMessage = ({ >
{message.role === "assistant" && ( @@ -35,54 +35,62 @@ export const PreviewMessage = ({ )}
- {message.content && ( -
- {message.content as string} -
- )} - - {message.toolInvocations && message.toolInvocations.length > 0 && ( -
- {message.toolInvocations.map((toolInvocation) => { - const { toolName, toolCallId, state } = toolInvocation; - - if (state === "result") { - const { result } = toolInvocation; + {message.parts && + message.parts.map((part: any, index: number) => { + if (part.type === "text") { + return ( +
+ {part.text} +
+ ); + } + // Handle tool calls - type is "tool-{toolName}" in AI SDK v5 + if (part.type?.startsWith("tool-")) { + const { toolCallId, state, output } = part; + const toolName = part.type.replace("tool-", ""); + if (state === "output-available" && output) { return (
{toolName === "get_current_weather" ? ( - + ) : ( -
{JSON.stringify(result, null, 2)}
+
{JSON.stringify(output, null, 2)}
)}
); } + // Show loading state while tool is executing + if ( + state === "input-streaming" || + state === "input-available" + ) { + return ( +
+ {toolName === "get_current_weather" ? : null} +
+ ); + } + } + if (part.type === "file") { return ( -
- {toolName === "get_current_weather" ? : null} -
+ ); - })} -
- )} - - {message.experimental_attachments && ( -
- {message.experimental_attachments.map((attachment) => ( - - ))} -
- )} + } + return null; + })}
@@ -104,7 +112,7 @@ export const ThinkingMessage = () => { "flex gap-4 group-data-[role=user]/message:px-3 w-full group-data-[role=user]/message:w-fit group-data-[role=user]/message:ml-auto group-data-[role=user]/message:max-w-2xl group-data-[role=user]/message:py-2 rounded-xl", { "group-data-[role=user]/message:bg-muted": true, - }, + } )} >
diff --git a/components/multimodal-input.tsx b/components/multimodal-input.tsx index 58bac1a5..1532e1da 100644 --- a/components/multimodal-input.tsx +++ b/components/multimodal-input.tsx @@ -1,6 +1,12 @@ "use client"; -import type { ChatRequestOptions, CreateMessage, Message } from "ai"; +import type { CreateUIMessage, UIMessage } from "@ai-sdk/react"; + +type ChatRequestOptions = { + headers?: Record | Headers; + body?: object; + data?: any; +}; import { motion } from "framer-motion"; import type React from "react"; import { @@ -49,17 +55,17 @@ export function MultimodalInput({ setInput: (value: string) => void; isLoading: boolean; stop: () => void; - messages: Array; - setMessages: Dispatch>>; + messages: Array; + setMessages: Dispatch>>; append: ( - message: Message | CreateMessage, - chatRequestOptions?: ChatRequestOptions, + message: UIMessage | CreateUIMessage, + chatRequestOptions?: ChatRequestOptions ) => Promise; handleSubmit: ( event?: { preventDefault?: () => void; }, - chatRequestOptions?: ChatRequestOptions, + chatRequestOptions?: ChatRequestOptions ) => void; className?: string; }) { @@ -75,13 +81,15 @@ export function MultimodalInput({ const adjustHeight = () => { if (textareaRef.current) { textareaRef.current.style.height = "auto"; - textareaRef.current.style.height = `${textareaRef.current.scrollHeight + 2}px`; + textareaRef.current.style.height = `${ + textareaRef.current.scrollHeight + 2 + }px`; } }; const [localStorageInput, setLocalStorageInput] = useLocalStorage( "input", - "", + "" ); useEffect(() => { @@ -150,11 +158,11 @@ export function MultimodalInput({