Perform Database Queries using LLMs and LangChain

Published November 07, 2023

Introduction

In this article, we will explore how to use SQLCoder-7B, a Large Language Model (LLM) that we will deploy on Amazon SageMaker, together with LangChain, to perform Natural Language Querying (NLQ).
We will see how to use LangChain to build a pipeline that prompts the LLM to generate an SQL query, retrieves the data from a PostgreSQL database, and passes the results back to the LLM as context to obtain the final response.

SQLCoder

SQLCoder is a family of Large Language Models (LLMs) built for efficient generation of SQL queries from natural language.
We will use SQLCoder-7B, which is based on Mistral-7B and has been fine-tuned for SQL query generation.
According to its creators: "SQLCoder-7B outperforms GPT-3.5 Turbo and other popular open-source models in natural language to SQL tasks. Additionally, it even surpasses GPT-4 when fine-tuned on a specific database schema."

Setting Up The Environment

Provisioning the Database

Let's begin by provisioning a PostgreSQL database using Amazon RDS.
We will use the following Terraform code snippet to accomplish this task:

# ------------------------------------------------------------------------------
# RDS Security group
# ------------------------------------------------------------------------------
resource "aws_security_group" "db_sg" {
  name_prefix = local.db_security_group_name_prefix
  vpc_id      = local.vpc_id

  ingress {
    from_port   = local.db_port
    to_port     = local.db_port
    protocol    = "tcp"
    cidr_blocks = [local.my_ip_address]
  }

  egress {
    from_port   = 0
    to_port     = 0
    protocol    = "-1"
    cidr_blocks = ["0.0.0.0/0"]
  }
}

# ------------------------------------------------------------------------------
# RDS
# ------------------------------------------------------------------------------
module "db" {
  source     = "terraform-aws-modules/rds/aws"
  identifier = local.db_identifier

  engine         = "postgres"
  engine_version = "15.4"
  family         = "postgres15"

  instance_class    = local.db_instance_class
  allocated_storage = local.db_allocated_storage

  db_name  = local.db_name
  username = local.db_username
  port     = local.db_port

  create_db_subnet_group = true
  vpc_security_group_ids = [aws_security_group.db_sg.id]
  subnet_ids             = local.db_subnet_ids

}
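
The snippet above references several local values. Here is an illustrative example of what they might look like; every value below is an assumption to adapt to your own environment:

locals {
  db_security_group_name_prefix = "nlq-db-"
  vpc_id                        = "<VPC_ID>"
  my_ip_address                 = "<YOUR_IP_ADDRESS>/32"

  db_identifier        = "nlq-demo"
  db_port              = 5432
  db_instance_class    = "db.t3.micro"
  db_allocated_storage = 20
  db_name              = "ecommerce"
  db_username          = "postgres"
  db_subnet_ids        = ["<SUBNET_ID_A>", "<SUBNET_ID_B>"]
}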

Initializing the Database

Now, let's populate our newly created database with data from the E-Commerce Data dataset, which you can download as a CSV file from Kaggle.
After downloading the dataset, connect to your RDS database using the following command:

psql -h <RDS_ENDPOINT> -p <RDS_PORT> -U <DATABASE_USERNAME> -d <DATABASE_NAME> -W

Enter your password when prompted. Once connected, create a sales table with the following SQL command:

CREATE TABLE sales (
    invoiceno VARCHAR(255),
    stockcode VARCHAR(255),
    description VARCHAR(255),
    quantity INT,
    invoicedate TIMESTAMP,
    unitprice DECIMAL(10, 2),
    customerid INT,
    country VARCHAR(50)
);

Next, copy the data from the CSV file into the sales table using the following command:

\COPY sales(invoiceno, stockcode, description, quantity, invoicedate, unitprice, customerid, country) FROM '/path/to/data.csv' DELIMITER ',' CSV HEADER;
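
Note: if the import fails with an encoding error (the Kaggle CSV file is not necessarily UTF-8 encoded), you can specify the file encoding explicitly, for example:

\COPY sales(invoiceno, stockcode, description, quantity, invoicedate, unitprice, customerid, country) FROM '/path/to/data.csv' WITH (FORMAT csv, HEADER, ENCODING 'LATIN1');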

Finally, verify that the data has been successfully loaded by running a simple count query:

SELECT COUNT(*) FROM sales;
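
You can also preview a few rows to confirm that the columns were parsed as expected:

SELECT * FROM sales LIMIT 5;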

Deploying SQLCoder-7B on Amazon SageMaker

If you've followed my previous article Deploy Your Own Private LLM Chatbot, simply update the code to deploy the model defog/sqlcoder-7b as follows:

locals {
    hugging_face_model_id = "defog/sqlcoder-7b"
}
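
If you haven't followed that article, here is a minimal sketch of deploying the same model with the SageMaker Python SDK and the Hugging Face Text Generation Inference (TGI) container. The execution role ARN, instance type, endpoint name, and token limits below are placeholders and assumptions to adapt to your own setup:

from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri

ROLE_ARN = "<SAGEMAKER_EXECUTION_ROLE_ARN>"  # placeholder

# Hugging Face TGI container image for the current region
image_uri = get_huggingface_llm_image_uri("huggingface")

model = HuggingFaceModel(
    role=ROLE_ARN,
    image_uri=image_uri,
    env={
        "HF_MODEL_ID": "defog/sqlcoder-7b",  # model pulled from the Hugging Face Hub
        "SM_NUM_GPUS": "1",
        "MAX_INPUT_LENGTH": "3072",
        "MAX_TOTAL_TOKENS": "4096",
    },
)

# Instance type is an assumption; adjust it to your quotas and latency needs
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.2xlarge",
    endpoint_name="sqlcoder-7b",
)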

Creating The LangChain Pipeline

Developing The Script

Now that your database is set up and populated with data, and your model has been deployed, let's write a simple script that builds a SQLDatabaseChain from the database connection and the Amazon SageMaker endpoint.
If you're new to the concept of Retrieval Augmented Generation (RAG), please refer to my previous article Create Context-Aware LLM Chatbot using Amazon Bedrock and LangChain.
Use the following Python script to enable natural language queries to be executed on the data stored in the database:

import boto3
import json
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint, LLMContentHandler
from typing import Dict
from sqlalchemy.exc import ProgrammingError

# RDS configuration
RDS_DB_NAME = "<RDS_DB_NAME>"
RDS_ENDPOINT = "<RDS_ENDPOINT>"
RDS_USERNAME = "<RDS_USERNAME>"
RDS_PASSWORD = "<RDS_PASSWORD>"
RDS_PORT = "<RDS_PORT>"
RDS_URI = f"postgresql+psycopg2://{RDS_USERNAME}:{RDS_PASSWORD}@{RDS_ENDPOINT}:{RDS_PORT}/{RDS_DB_NAME}"
db = SQLDatabase.from_uri(
    RDS_URI,
    include_tables=["sales"],
    sample_rows_in_table_info=2,
)

# Sagemaker configuration
SAGEMAKER_ENDPOINT_NAME = "<SAGEMAKER_ENDPOINT_NAME>"
MAX_TOKENS = 1024

class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        input_str = json.dumps({"inputs": prompt.strip(), "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
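        # Keep only the first line of the generated text, which contains the
        # SQL query (or the final answer) produced by the model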
        response = response_json[0]["generated_text"].strip().split("\n")[0]
        return response

content_handler = ContentHandler()

sagemaker_client = boto3.client("runtime.sagemaker")
llm = SagemakerEndpoint(
    client=sagemaker_client,
    endpoint_name=SAGEMAKER_ENDPOINT_NAME,
    model_kwargs={
        "max_new_tokens": MAX_TOKENS,
        "return_full_text": False,
    },
    content_handler=content_handler,
)

# Chain
db_chain = SQLDatabaseChain.from_llm(
    llm,
    db,
    verbose=True,
)

while True:
    user_input = input("Enter a message (or 'exit' to quit): ")

    if user_input.lower() == "exit":
        break

    try:
        results = db_chain.run(user_input)
        print(results)
    except (ProgrammingError, ValueError) as exc:
        print(f"\n\n{exc}")

When run, the script prompts the user to enter messages, which are then passed through the SQLDatabaseChain.
This script is for demonstration purposes only and can be customized further.
LangChain offers the flexibility to customize prompts for better results when interacting with the LLM.
Another way to customize the script is to provide detailed table definitions to the LLM.
By doing so, you give the model additional context about the structure of the tables being queried, which can help it generate more accurate and relevant SQL queries, as sketched below.
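
Building on the script above, here is a hedged sketch of both customizations. The column comments in the table definition are assumptions about the dataset, and the prompt simply mirrors the Question / SQLQuery / SQLResult / Answer structure that SQLDatabaseChain works with:

from langchain.prompts import PromptTemplate

# Hand-written table definition (with assumed column comments) passed to
# LangChain instead of the schema it would introspect automatically
custom_table_info = {
    "sales": """CREATE TABLE sales (
    invoiceno VARCHAR(255),   -- invoice number shared by all lines of an order
    stockcode VARCHAR(255),   -- product code
    description VARCHAR(255), -- product name
    quantity INT,             -- number of units sold
    invoicedate TIMESTAMP,    -- date and time of the sale
    unitprice DECIMAL(10, 2), -- price per unit
    customerid INT,           -- customer identifier
    country VARCHAR(50)       -- customer country
);"""
}

db = SQLDatabase.from_uri(
    RDS_URI,
    include_tables=["sales"],
    custom_table_info=custom_table_info,
)

# Custom prompt following the Question / SQLQuery / SQLResult / Answer
# structure used by SQLDatabaseChain
_TEMPLATE = """Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless told otherwise, limit results to at most {top_k} rows.

Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
{table_info}

Question: {input}"""

PROMPT = PromptTemplate(
    input_variables=["input", "table_info", "dialect", "top_k"],
    template=_TEMPLATE,
)

db_chain = SQLDatabaseChain.from_llm(llm, db, prompt=PROMPT, verbose=True)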

Testing The Script

Let's see what's happening under the hood.
I ran the script and entered the question "What is the most sold product?".
Here we can see that LangChain generated a prompt containing the table schema and two sample rows from the table.
The prompt was then sent to the Amazon SageMaker endpoint to generate an SQL query based on the given context.

[Image: sql_request (prompt sent to the SageMaker endpoint)]

As a result, the model returned the generated SQL query, as shown in the image below: [Image: sql_query]

LangChain executed the query against the database and produced the following results: [Image: sql_result]

With these results in hand, LangChain called the LLM again, filling the SQLResult into the prompt and requesting an Answer.
The model responded with the final answer, which can be seen in the following image: [Image: final_answer]

This successful outcome shows that the script is capable of handling natural language queries.
To test the script further, consider trying out more complex queries.
It is essential to use a read-only database user for the connection, as the LLM can potentially generate data-modifying statements such as INSERT, DELETE, or ALTER.
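
For example, a dedicated read-only role could be created as follows and used in RDS_URI instead of the admin user (the role name and password are placeholders):

CREATE ROLE nlq_readonly WITH LOGIN PASSWORD '<STRONG_PASSWORD>';
GRANT CONNECT ON DATABASE <DATABASE_NAME> TO nlq_readonly;
GRANT USAGE ON SCHEMA public TO nlq_readonly;
GRANT SELECT ON ALL TABLES IN SCHEMA public TO nlq_readonly;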

Conclusion

Such integration opens the door to a wide range of applications, from data analysis to chatbots capable of answering complex database-related inquiries. If you encounter any challenges during the integration process or have specific requirements that aren't covered in this guide, feel free to reach out to me. I'm here to help!

Disclaimer

The information presented in this article is intended for informational purposes only. I assume no responsibility or liability for any use, misuse, or interpretation of the information contained herein.