Introduction
In this article, we will explore how to use SQLCoder-7B, a Large Language Model (LLM) that we will deploy on Amazon SageMaker, along with LangChain to perform Natural Language Querying (NLQ).
We will see how to use LangChain to create a pipeline that prompts the LLM to generate an SQL query, runs that query against a PostgreSQL database, and passes the results back to the LLM as context to obtain the final response.
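To make these three steps concrete, here is a minimal, library-free sketch of the same loop. It is not how LangChain implements SQLDatabaseChain: `fake_llm` is a hypothetical stand-in for the SageMaker endpoint that returns canned text, and an in-memory SQLite database replaces PostgreSQL.

```python
import sqlite3
from typing import Callable


def answer_question(question: str, llm: Callable[[str], str], conn: sqlite3.Connection) -> str:
    """Generate SQL with the LLM, run it, then ask the LLM for the final answer."""
    # Step 1: prompt the LLM for an SQL query
    sql = llm(f"Question: {question}\nSQLQuery:").strip()
    # Step 2: execute the generated query against the database
    rows = conn.execute(sql).fetchall()
    # Step 3: pass the query result back to the LLM as context
    return llm(f"Question: {question}\nSQLQuery: {sql}\nSQLResult: {rows}\nAnswer:").strip()


def fake_llm(prompt: str) -> str:
    """Hypothetical model stand-in: returns a fixed query, then echoes the result."""
    if prompt.endswith("SQLQuery:"):
        return ("SELECT description, SUM(quantity) AS total FROM sales "
                "GROUP BY description ORDER BY total DESC LIMIT 1")
    result_line = [line for line in prompt.splitlines() if line.startswith("SQLResult:")][0]
    return f"Answer derived from: {result_line[len('SQLResult: '):]}"


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE sales (description TEXT, quantity INTEGER)")
    conn.executemany("INSERT INTO sales VALUES (?, ?)", [("MUG", 5), ("LAMP", 2), ("MUG", 3)])
    print(answer_question("What is the most sold product?", fake_llm, conn))
```

In the real pipeline, LangChain builds far richer prompts (schema, sample rows, format instructions), but the control flow is the same two-call pattern.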
SQLCoder
SQLCoder is a collection of Large Language Models (LLMs) built for efficient generation of SQL queries from natural language.
We will use SQLCoder-7B, which is based on Mistral-7B and has been fine-tuned for SQL query generation.
According to its creators: "SQLCoder-7B outperforms GPT-3.5 Turbo and other popular open-source models in natural language to SQL tasks. Additionally, it even surpasses GPT-4 when fine-tuned on a specific database schema."
Setting Up The Environment
Provisioning the Database
Let's begin by provisioning a PostgreSQL database using Amazon RDS.
We will use the following Terraform code snippet to accomplish this task:
# ------------------------------------------------------------------------------
# RDS Security group
# ------------------------------------------------------------------------------
resource "aws_security_group" "db_sg" {
  name_prefix = local.db_security_group_name_prefix
  vpc_id      = local.vpc_id
  ingress {
    from_port   = local.db_port
    to_port     = local.db_port
    protocol    = "tcp"
    cidr_blocks = [local.my_ip_address]
  }
  egress {
    from_port   = 0
    to_port     = 0
    protocol    = "-1"
    cidr_blocks = ["0.0.0.0/0"]
  }
}
# ------------------------------------------------------------------------------
# RDS
# ------------------------------------------------------------------------------
module "db" {
  source                 = "terraform-aws-modules/rds/aws"
  identifier             = local.db_identifier
  engine                 = "postgres"
  engine_version         = "15.4"
  family                 = "postgres15"
  instance_class         = local.db_instance_class
  allocated_storage      = local.db_allocated_storage
  db_name                = local.db_name
  username               = local.db_username
  port                   = local.db_port
  create_db_subnet_group = true
  vpc_security_group_ids = [aws_security_group.db_sg.id]
  subnet_ids             = local.db_subnet_ids
}
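The ingress rule above only admits `local.my_ip_address`, and `cidr_blocks` expects IPv4 CIDR notation (for example `203.0.113.10/32`), not a bare address. A small sketch, assuming you have already looked up your public IPv4 address yourself, that turns it into a single-host block you can paste into your locals:

```python
import ipaddress


def to_single_host_cidr(address: str) -> str:
    """Return a /32 CIDR block for one IPv4 address, suitable for cidr_blocks."""
    # IPv4Address raises ValueError for malformed input and for IPv6 addresses,
    # which belong in a security group's ipv6_cidr_blocks instead
    addr = ipaddress.IPv4Address(address.strip())
    return f"{addr}/32"


if __name__ == "__main__":
    print(to_single_host_cidr("203.0.113.10"))
```

Using `/32` keeps the database port open to exactly one host rather than a whole range.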
Initializing the Database
Now, let's populate our newly created database with data from the E-Commerce Data dataset, which you can download as a CSV file from Kaggle.
After downloading the dataset, connect to your RDS database using the following command:
psql -h <RDS_ENDPOINT> -p <RDS_PORT> -U <DATABASE_USERNAME> -d <DATABASE_NAME> -W
Enter your password when prompted. Once connected, create a sales table with the following SQL command:
CREATE TABLE sales (
    invoiceno VARCHAR(255),
    stockcode VARCHAR(255),
    description VARCHAR(255),
    quantity INT,
    invoicedate TIMESTAMP,
    unitprice DECIMAL(10, 2),
    customerid INT,
    country VARCHAR(50)
);
Next, copy the data from the CSV file into the sales table using the following command:
\COPY sales(invoiceno, stockcode, description, quantity, invoicedate, unitprice, customerid, country) FROM '/path/to/data.csv' DELIMITER ',' CSV HEADER;
Finally, verify that the data has been successfully loaded by running a simple count query:
SELECT COUNT(*) FROM sales;
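Note that `CSV HEADER` only tells `\COPY` to skip the first line: values are mapped to the listed columns by position, not by header name. The following sketch uses only Python's standard `csv` module to check the header order against the column list and to count the data rows, so you can compare that number with the `SELECT COUNT(*)` result. The default `latin-1` encoding is an assumption chosen because it can decode any byte; switch to `utf-8` if that matches your file.

```python
import csv
from typing import Iterable, List, Tuple

# Column order used in the CREATE TABLE and \COPY commands above
EXPECTED_COLUMNS = [
    "invoiceno", "stockcode", "description", "quantity",
    "invoicedate", "unitprice", "customerid", "country",
]


def inspect_rows(lines: Iterable[str]) -> Tuple[List[str], int]:
    """Return the lower-cased header and the number of non-empty data rows."""
    reader = csv.reader(lines)
    header = [name.strip().lower() for name in next(reader)]
    row_count = sum(1 for row in reader if row)
    return header, row_count


def check_csv(path: str, encoding: str = "latin-1") -> int:
    """Raise if the header order differs from EXPECTED_COLUMNS, else return the row count."""
    # newline="" lets the csv module handle quoted fields that contain line breaks
    with open(path, newline="", encoding=encoding) as f:
        header, row_count = inspect_rows(f)
    if header != EXPECTED_COLUMNS:
        raise ValueError(f"CSV columns {header} do not match {EXPECTED_COLUMNS}")
    return row_count

# Example: print(check_csv("/path/to/data.csv"))
```

If the printed count matches `SELECT COUNT(*) FROM sales;`, every data row made it into the table.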
Deploying SQLCoder-7B on Amazon SageMaker
If you've followed my previous article, Deploy Your Own Private LLM Chatbot, simply update the code to deploy the model defog/sqlcoder-7b as follows:
locals {
  hugging_face_model_id = "defog/sqlcoder-7b"
}
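Before wiring the endpoint into LangChain, it can help to call it directly and confirm that it responds. A minimal sketch using the `invoke_endpoint` call of boto3's `sagemaker-runtime` client. The payload mirrors the `inputs`/`parameters` JSON that the content handler later in this article sends, and it assumes the endpoint returns a list such as `[{"generated_text": ...}]`, which is what that handler expects. The client is passed in, so any object with an `invoke_endpoint` method works.

```python
import json
from typing import Any


def build_payload(prompt: str, max_new_tokens: int = 256) -> bytes:
    """Encode the prompt in the inputs/parameters JSON format."""
    body = {
        "inputs": prompt,
        "parameters": {"max_new_tokens": max_new_tokens, "return_full_text": False},
    }
    return json.dumps(body).encode("utf-8")


def invoke_sqlcoder(client: Any, endpoint_name: str, prompt: str) -> str:
    """Call the SageMaker endpoint and return the generated text."""
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=build_payload(prompt),
    )
    # response["Body"] is a stream: read it, then decode the JSON list
    result = json.loads(response["Body"].read().decode("utf-8"))
    return result[0]["generated_text"]

# Example:
# import boto3
# client = boto3.client("sagemaker-runtime")
# print(invoke_sqlcoder(client, "<SAGEMAKER_ENDPOINT_NAME>", "Write a query that counts rows in sales"))
```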
Creating The LangChain Pipeline
Developing The Script
Now that your database is set up and populated with data, and your model has been deployed, we will create a simple script that builds a SQLDatabaseChain from the database connection and the Amazon SageMaker endpoint.
If you're new to the concept of Retrieval Augmented Generation (RAG), please refer to my previous article Create Context-Aware LLM Chatbot using Amazon Bedrock and LangChain.
Use the following Python script to enable natural language queries to be executed on the data stored in the database:
import boto3
import json
from typing import Dict
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint, LLMContentHandler
from sqlalchemy.exc import ProgrammingError

# RDS configuration
RDS_DB_NAME = "<RDS_DB_NAME>"
RDS_ENDPOINT = "<RDS_ENDPOINT>"
RDS_USERNAME = "<RDS_USERNAME>"
RDS_PASSWORD = "<RDS_PASSWORD>"
RDS_PORT = "<RDS_PORT>"
RDS_URI = f"postgresql+psycopg2://{RDS_USERNAME}:{RDS_PASSWORD}@{RDS_ENDPOINT}:{RDS_PORT}/{RDS_DB_NAME}"
# Expose only the sales table and include 2 sample rows in the table info sent to the LLM
db = SQLDatabase.from_uri(
    RDS_URI,
    include_tables=["sales"],
    sample_rows_in_table_info=2,
)

# SageMaker configuration
SAGEMAKER_ENDPOINT_NAME = "<SAGEMAKER_ENDPOINT_NAME>"
MAX_TOKENS = 1024


class ContentHandler(LLMContentHandler):
    """Converts between LangChain prompts and the endpoint's JSON format."""

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
        # Send the prompt as "inputs" and the generation settings as "parameters"
        input_str = json.dumps({"inputs": prompt.strip(), "parameters": model_kwargs})
        return input_str.encode("utf-8")

    def transform_output(self, output) -> str:
        # `output` is the streaming response body, so it must be read before decoding
        response_json = json.loads(output.read().decode("utf-8"))
        # Keep only the first line of the generated text
        response = response_json[0]["generated_text"].strip().split("\n")[0]
        return response


content_handler = ContentHandler()
sagemaker_client = boto3.client("runtime.sagemaker")
llm = SagemakerEndpoint(
    client=sagemaker_client,
    endpoint_name=SAGEMAKER_ENDPOINT_NAME,
    model_kwargs={
        "max_new_tokens": MAX_TOKENS,
        "return_full_text": False,
    },
    content_handler=content_handler,
)

# Chain: generate SQL with the LLM, run it on the database, then ask the LLM for the answer
db_chain = SQLDatabaseChain.from_llm(
    llm,
    db,
    verbose=True,
)

# Simple interactive loop: each message is passed through the SQLDatabaseChain
while True:
    user_input = input("Enter a message (or 'exit' to quit): ")
    if user_input.lower() == "exit":
        break
    try:
        results = db_chain.run(user_input)
        print(results)
    except (ProgrammingError, ValueError) as exc:
        # Invalid generated SQL raises ProgrammingError; print it and keep going
        print(f"\n\n{exc}")
When you run this script, it prompts you to enter messages, which are then passed through the SQLDatabaseChain.
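One caveat with building `RDS_URI` through an f-string: if the password contains characters such as `@`, `:` or `/`, the URI will be parsed incorrectly. A small sketch that percent-encodes the credentials first with the standard library's `urllib.parse.quote_plus`, the approach the SQLAlchemy documentation suggests for special characters in passwords. The helper name `build_postgres_uri` is hypothetical.

```python
from urllib.parse import quote_plus


def build_postgres_uri(username: str, password: str, host: str, port: int, db_name: str) -> str:
    """Build a SQLAlchemy psycopg2 URI with URL-encoded credentials."""
    # quote_plus encodes "@", ":" and "/" so they cannot be mistaken for URI delimiters
    user = quote_plus(username)
    pwd = quote_plus(password)
    return f"postgresql+psycopg2://{user}:{pwd}@{host}:{port}/{db_name}"


if __name__ == "__main__":
    print(build_postgres_uri("admin", "p@ss:w/rd", "db.example.com", 5432, "shop"))
```

The result can be passed to `SQLDatabase.from_uri` in place of the f-string version.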
This script is for demonstration purposes only and can be customized further.
For example, LangChain lets you customize the prompts to get better results from the LLM.
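As a sketch of such a customization, the template below keeps the Question/SQLQuery/SQLResult/Answer format that SQLDatabaseChain works with and, because the content handler above keeps only the first line of the model's output, asks for a single-line query. The wording is hypothetical; the placeholders `{dialect}`, `{top_k}`, `{table_info}` and `{input}` are values the chain fills in. The chain appends a newline and `SQLQuery:` to the question itself, which is why the template ends with `Question: {input}`; the helper below reproduces that so you can preview the first prompt.

```python
# Hypothetical wording; only the placeholder names are tied to SQLDatabaseChain.
CUSTOM_SQL_TEMPLATE = """You are a {dialect} expert. Given an input question, first write a syntactically correct {dialect} query on a single line, then look at the results of the query and return the answer to the input question.
Unless the question asks for a specific number of rows, return at most {top_k} rows using LIMIT.
Use the following format:

Question: Question here
SQLQuery: SQL Query to run
SQLResult: Result of the SQLQuery
Answer: Final answer here

Only use the following tables:
{table_info}

Question: {input}"""


def render_first_call(question: str, table_info: str, dialect: str = "postgresql", top_k: int = 5) -> str:
    """Preview the SQL-generation prompt, appending SQLQuery: as the chain does."""
    return CUSTOM_SQL_TEMPLATE.format(
        dialect=dialect, top_k=top_k, table_info=table_info, input=f"{question}\nSQLQuery:"
    )
```

To use it with the chain, you could wrap the string as `PromptTemplate(input_variables=["input", "table_info", "dialect", "top_k"], template=CUSTOM_SQL_TEMPLATE)` and pass it as `prompt=` to `SQLDatabaseChain.from_llm`.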
Another way to customize the script is to give the LLM detailed table definitions.
These add context about the structure and meaning of the tables being queried, which can help the LLM generate more accurate and relevant SQL queries.
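A sketch of such a definition: the `sales` schema from earlier, rendered as a `CREATE TABLE` statement with one comment per column. The column notes are illustrative descriptions, not documentation shipped with the dataset. LangChain's `SQLDatabase` accepts a `custom_table_info` mapping of table name to text that replaces the auto-generated description for that table; check that your installed version supports it.

```python
from typing import List, Tuple

# (column, SQL type, note) triples; the notes are illustrative
SALES_COLUMNS: List[Tuple[str, str, str]] = [
    ("invoiceno", "VARCHAR(255)", "invoice number, shared by all lines of one invoice"),
    ("stockcode", "VARCHAR(255)", "product code"),
    ("description", "VARCHAR(255)", "product name"),
    ("quantity", "INT", "units of the product on this invoice line"),
    ("invoicedate", "TIMESTAMP", "date and time of the invoice"),
    ("unitprice", "DECIMAL(10, 2)", "price of one unit"),
    ("customerid", "INT", "customer identifier, may be NULL"),
    ("country", "VARCHAR(50)", "customer's country"),
]


def build_table_info(table: str, columns: List[Tuple[str, str, str]]) -> str:
    """Render a CREATE TABLE statement with a SQL comment after each column."""
    lines = [f"CREATE TABLE {table} ("]
    for i, (name, sql_type, note) in enumerate(columns):
        # The comma must come before the "--" comment to keep the DDL valid
        comma = "," if i < len(columns) - 1 else ""
        lines.append(f"  {name} {sql_type}{comma} -- {note}")
    lines.append(");")
    return "\n".join(lines)


custom_table_info = {"sales": build_table_info("sales", SALES_COLUMNS)}
```

The dictionary could then be passed as `SQLDatabase.from_uri(RDS_URI, include_tables=["sales"], custom_table_info=custom_table_info)`.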
Testing The Script
Let's see what's happening under the hood.
I ran the script and entered the question "What is the most sold product?".
Here we can see that LangChain has generated a prompt that contains the table schema and two sample rows from the table.
LangChain then sent this prompt to the Amazon SageMaker endpoint, asking the model to generate an SQL query based on the given context.
As a result, the model returned the generated SQL query, as shown in the image below:
LangChain executed the query against the database and produced the following results:
With these results in hand, LangChain called the LLM again, filling in the SQLResult in the prompt and asking for an Answer.
The model responded with a final answer, which can be seen in the following image:
This successful outcome suggests that the script can handle natural language queries.
To test the script further, try more complex queries.
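One way to do that systematically is to run a fixed list of questions and record which ones fail. In this minimal sketch, `run_fn` stands for `db_chain.run` from the script above (any callable that takes a question and returns a string works), and every exception is caught so that one bad query does not stop the batch. The helper name is hypothetical.

```python
from typing import Callable, Dict, List, Tuple


def run_question_suite(run_fn: Callable[[str], str], questions: List[str]) -> Dict[str, List[Tuple[str, str]]]:
    """Run each question, collecting answers and failures instead of stopping at the first error."""
    report: Dict[str, List[Tuple[str, str]]] = {"answered": [], "failed": []}
    for question in questions:
        try:
            report["answered"].append((question, run_fn(question)))
        except Exception as exc:  # e.g. invalid SQL rejected by the database
            report["failed"].append((question, repr(exc)))
    return report

# Example:
# report = run_question_suite(db_chain.run, [
#     "Which country has the most customers?",
#     "What was the total revenue per month in 2011?",
# ])
# print(report["failed"])
```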
It's essential to use a read-only database user for this connection, as the LLM can potentially generate statements that modify your data or schema, such as Data Manipulation Language (DML) queries like INSERT and DELETE, or Data Definition Language (DDL) statements like ALTER.
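A read-only database user is the real safeguard. As an extra application-side check, here is a deliberately naive sketch that accepts only a single SELECT (or WITH ... SELECT) statement and rejects anything containing data- or schema-changing keywords. It is a keyword heuristic, not a SQL parser, and a hypothetical helper rather than part of LangChain: it can reject valid queries (for example, a string literal containing "delete"), so it should complement database permissions, never replace them.

```python
import re

# Keywords that indicate data or schema changes (DML, DDL and permission changes)
_FORBIDDEN = re.compile(
    r"\b(insert|update|delete|merge|alter|drop|create|truncate|grant|revoke|copy)\b",
    re.IGNORECASE,
)


def is_safe_select(sql: str) -> bool:
    """Heuristically accept only one read-only SELECT/WITH statement."""
    statement = sql.strip().rstrip(";").strip()
    if not statement or ";" in statement:
        return False  # empty, or more than one statement
    if not re.match(r"(select|with)\b", statement, re.IGNORECASE):
        return False
    # WITH can wrap DML in PostgreSQL, so scan the whole statement for keywords
    return _FORBIDDEN.search(statement) is None
```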
Conclusion
Such an integration opens the door to a wide range of applications, from data analysis to chatbots capable of answering complex database-related questions. If you encounter any challenges during the integration process or have specific requirements that aren't covered in this guide, feel free to reach out to me. I'm here to help!
Disclaimer
The information presented in this article is intended for informational purposes only. I assume no responsibility or liability for any use, misuse, or interpretation of the information contained herein.