{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Running RolmOCR on Vast\n", "\n", "[RolmOCR](https://huggingface.co/reducto/RolmOCR) from Reducto is a powerful, open-source document OCR solution that delivers superior performance while requiring fewer resources than comparable models. Built on Qwen2.5-VL-7B, this model excels at parsing complex documents including PDFs, invoices, and forms without requiring metadata extraction. Because it is open source, companies can build on top of this model for their own proprietary pipelines without sending data to other providers or model hosting services.\n", "\n", "Vast.ai offers a GPU marketplace where you can rent compute power at lower cost than major cloud providers, with the flexibility to select hardware configurations suited to a given model, while keeping all of your company's data private.\n", "\n", "This notebook demonstrates how to extract structured data (invoice numbers and total amounts) from invoice images using `reducto/RolmOCR`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Deploy `reducto/RolmOCR` on Vast\n", "\n", "### Install Vast\n", "\n", "First, we will install the Vast CLI and set our API key.\n", "\n", "You can get your API key on the [Account Page](https://cloud.vast.ai/account/) in the Vast Console and set it below in `VAST_API_KEY`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%bash\n", "pip install vastai==0.2.6" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%bash\n", "export VAST_API_KEY=\"\" # Your key here\n", "vastai set api-key $VAST_API_KEY" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Search for an Instance\n", "\n", "Next, we'll search for an instance to host our model. While `reducto/RolmOCR` requires at least 16 GB of VRAM, we'll select an instance with at least 60 GB of VRAM to accommodate larger documents and enable a wider context window." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "shellscript" } }, "outputs": [], "source": [ "%%bash\n", "vastai search offers \"compute_cap >= 750 \\\n", "geolocation=US \\\n", "gpu_ram >= 60 \\\n", "num_gpus = 1 \\\n", "static_ip = true \\\n", "direct_port_count >= 1 \\\n", "verified = true \\\n", "disk_space >= 80 \\\n", "rentable = true\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Deploy Our Instance\n", "\n", "Next, we will use the instance ID from our search to create an instance that serves the model with vLLM.\n", "\n", "Note: we set `VLLM_USE_V1=1` to use the v1 engine for vLLM, which `reducto/RolmOCR` requires." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "shellscript" } }, "outputs": [], "source": [ "%%bash\n", "export INSTANCE_ID= # insert instance ID\n", "vastai create instance $INSTANCE_ID --image vllm/vllm-openai:latest --env '-p 8000:8000 -e VLLM_USE_V1=1' --disk 80 --args --model reducto/RolmOCR" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using the OpenAI API to Call `reducto/RolmOCR`\n", "\n", "### Install Dependencies\n", "\n", "First, we will install our dependencies, including Pillow and Matplotlib, which the later cells use to encode and display the invoice images." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%%bash\n", "pip install --upgrade openai datasets pydantic pillow matplotlib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Download Sample Invoice Data\n", "\n", "We will use the `datasets` library to stream a small subset of the `katanaml-org/invoices-donut-data-v1` dataset on Hugging Face, which contains 500 annotated invoice images with structured metadata for training document extraction models." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset\n", "\n", "# Stream the dataset so we do not download all of it\n", "streamed_dataset = load_dataset(\"katanaml-org/invoices-donut-data-v1\", split=\"train\", streaming=True)\n", "\n", "# Take the first 3 samples\n", "subset = list(streamed_dataset.take(3))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will then create an `encode_pil_image` function that downsizes each image from our dataset, re-encodes it as JPEG, and returns a base64 string that we can pass to the OpenAI API." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import base64\n", "import io\n", "from PIL import Image\n", "\n", "def encode_pil_image(pil_image):\n", "    \"\"\"Downscale a PIL image, re-encode it as JPEG, and return a base64 string.\"\"\"\n", "    # Shrink so the longest side is at most max_size, keeping the aspect ratio;\n", "    # capping the ratio at 1.0 avoids upscaling images that are already small\n", "    max_size = 1024  # Maximum dimension in pixels\n", "    ratio = min(1.0, max_size / pil_image.width, max_size / pil_image.height)\n", "    new_size = (int(pil_image.width * ratio), int(pil_image.height * ratio))\n", "    resized_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)\n", "\n", "    # JPEG has no alpha channel, so convert to RGB before saving\n", "    img_byte_arr = io.BytesIO()\n", "    resized_image.convert(\"RGB\").save(img_byte_arr, format=\"JPEG\", quality=85)  # Reduced quality for smaller size\n", "    return base64.b64encode(img_byte_arr.getvalue()).decode(\"utf-8\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Enforcing Structured Data Extraction\n", "\n", "We'll define an `Invoice` schema with Pydantic and export it as a JSON schema. Passing that schema with each request constrains the model's output to these two fields." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from pydantic import BaseModel\n", "\n", "class Invoice(BaseModel):\n", "    invoice_number: str\n", "    invoice_amount: str\n", "\n", "\n", "json_schema = Invoice.model_json_schema()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Configuring the API Client\n", "\n", "Next, we'll create an `ocr_page_with_rolm` function that sends a base64-encoded image to the RolmOCR endpoint and passes `json_schema` as `guided_json`, so the response follows our `Invoice` schema.\n", "\n", "We'll also set `VAST_IP_ADDRESS` and `VAST_PORT` for our running instance. Both values are shown in the [Instances tab](https://cloud.vast.ai/instances/) of the Vast Console." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from openai import OpenAI\n", "\n", "VAST_IP_ADDRESS = \"\"\n", "VAST_PORT = \"\"\n", "base_url = f\"http://{VAST_IP_ADDRESS}:{VAST_PORT}/v1\"\n", "\n", "\n", "client = OpenAI(api_key=\"\", base_url=base_url)\n", "\n", "model = \"reducto/RolmOCR\"\n", "\n", "def ocr_page_with_rolm(img_base64):\n", "    response = client.chat.completions.create(\n", "        model=model,\n", "        messages=[\n", "            {\n", "                \"role\": \"user\",\n", "                \"content\": [\n", "                    {\n", "                        \"type\": \"image_url\",\n", "                        # encode_pil_image produces JPEG bytes, so the data URL declares image/jpeg\n", "                        \"image_url\": {\"url\": f\"data:image/jpeg;base64,{img_base64}\"},\n", "                    },\n", "                    {\n", "                        \"type\": \"text\",\n", "                        \"text\": \"Return the invoice number and total amount for each invoice as a json: {invoice_number : str, invoice_amount: str}\",\n", "                    },\n", "                ],\n", "            }\n", "        ],\n", "        # Constrain the output to the Invoice JSON schema\n", "        extra_body={\"guided_json\": json_schema},\n", "        temperature=0.2,\n", "        max_tokens=500\n", "    )\n", "    return response.choices[0].message.content" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we will iterate over our subset to extract `invoice_number` and `invoice_amount`. We will display each original invoice and compare our JSON output to the ground truth data." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import json\n", "\n", "invoices = []\n", "ground_truth = []\n", "for sample in subset:\n", "    # Display the image\n", "    plt.figure(figsize=(10, 14))\n", "    plt.imshow(sample['image'])\n", "    plt.axis('off')\n", "    plt.show()\n", "\n", "    # Process with OCR\n", "    img_base64 = encode_pil_image(sample['image'])\n", "    result = ocr_page_with_rolm(img_base64)\n", "    result_dict = json.loads(result)\n", "    invoices.append(result_dict)\n", "\n", "    # Pull the matching fields out of the dataset's annotation\n", "    ground_truth_i = json.loads(sample[\"ground_truth\"])\n", "    ground_truth_dict = {\n", "        \"invoice_number\": ground_truth_i[\"gt_parse\"][\"header\"][\"invoice_no\"],\n", "        \"invoice_amount\": ground_truth_i[\"gt_parse\"][\"summary\"][\"total_gross_worth\"]\n", "    }\n", "    ground_truth.append(ground_truth_dict)\n", "\n", "    print(\"Ground Truth\")\n", "    print(ground_truth_dict)\n", "\n", "    print(\"Extracted Info\")\n", "    print(result_dict)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Output\n", "\n", "![](./img/output_1.png)\n", "\n", "```\n", "Ground Truth\n", "{'invoice_number': '40378170', 'invoice_amount': '$8,25'}\n", "Extracted Info\n", "{'invoice_number': '40378170', 'invoice_amount': '$8.25'}\n", "```\n", "\n", "\n", "![](./img/output_2.png)\n", "\n", "\n", "```\n", "Ground Truth\n", "{'invoice_number': '61356291', 'invoice_amount': '$ 212,09'}\n", "Extracted Info\n", "{'invoice_number': '61356291', 'invoice_amount': '$212.09'}\n", "```\n", "\n", "\n", "![](./img/output_3.png)\n", "\n", "```\n", "Ground Truth\n", "{'invoice_number': '49565075', 'invoice_amount': '$96,73'}\n", "Extracted Info\n", "{'invoice_number': '49565075', 'invoice_amount': '$96,73'}\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The invoice numbers from `reducto/RolmOCR` match the ground truth exactly for all three invoices. The amounts agree in value, although the strings differ in formatting for the first two invoices: the ground truth uses a comma as the decimal separator (and, in one case, a space after the currency symbol), while the model returned a period." 
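] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Comparing the amounts as raw strings would flag pairs such as `$8,25` and `$8.25` as mismatches even though the values agree. As a minimal sketch, the hypothetical helper `normalize_amount` below (not part of the original pipeline or of any library) parses both strings into `Decimal` values before comparing them. It assumes the amounts contain no thousands separators, so a comma or a period is always the decimal separator; amounts such as `1,234.56` would need different handling." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import re\n", "from decimal import Decimal\n", "\n", "def normalize_amount(amount: str) -> Decimal:\n", "    \"\"\"Parse strings such as '$ 212,09' or '$212.09' into a Decimal.\n", "\n", "    Assumption: no thousands separators, so ',' and '.' both mean\n", "    the decimal separator.\n", "    \"\"\"\n", "    # Drop the currency symbol, spaces, and anything else that is not a digit or separator\n", "    cleaned = re.sub(r\"[^0-9.,]\", \"\", amount)\n", "    return Decimal(cleaned.replace(\",\", \".\"))\n", "\n", "# (ground truth, extracted) pairs copied from the output above\n", "pairs = [(\"$8,25\", \"$8.25\"), (\"$ 212,09\", \"$212.09\"), (\"$96,73\", \"$96,73\")]\n", "for expected, extracted in pairs:\n", "    same_value = normalize_amount(expected) == normalize_amount(extracted)\n", "    print(expected, extracted, same_value)"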
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Conclusion\n", "\n", "In this notebook, we've demonstrated how to deploy and use RolmOCR on Vast to extract structured data from invoice images.\n", "\n", "On the three sample invoices, RolmOCR identified every invoice number and total amount correctly. The combination of RolmOCR's efficiency with Vast.ai's cost-effective GPU options makes this a practical option for document processing workflows at scale.\n", "\n", "This approach can be extended to extract other types of structured data from various document formats, for example by adding fields to the `Invoice` schema and adjusting the prompt." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 2 }