talk2data / pandasai_visualization.py
amirkiarafiei's picture
Enhance visualization support and update project documentation
09e2bc4
raw
history blame
2.89 kB
#!/usr/bin/env python3
"""
Visualization script using PandasAI.
This script creates a sample dataframe and uses PandasAI to generate
and save visualizations based on user queries.
Usage:
python pandasai_visualization.py "Create a bar chart of sales by region"
Requirements:
- pandas
- pandasai
- matplotlib
- python-dotenv
"""
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt
import pandasai as pai
from dotenv import load_dotenv
def create_sample_dataframe():
    """Create a sample dataframe with sales data.

    The frame carries the columns a query such as "sales by region" needs
    (``Region``, ``Product``, ``Sales``, ``Quarter``) in addition to ``Year``.

    Returns:
        A ``pai.DataFrame`` with eight rows of sample sales records.
    """
    data = {
        'Region': ['North', 'South', 'East', 'West', 'North', 'South', 'East', 'West'],
        'Product': ['Widget', 'Widget', 'Widget', 'Widget', 'Gadget', 'Gadget', 'Gadget', 'Gadget'],
        'Sales': [150, 200, 120, 180, 90, 110, 95, 130],
        'Quarter': ['Q1', 'Q1', 'Q1', 'Q1', 'Q2', 'Q2', 'Q2', 'Q2'],
        'Year': [2023, 2023, 2023, 2023, 2023, 2023, 2023, 2023],
    }
    return pai.DataFrame(data)
def visualize_data(df, query):
    """
    Generate a visualization for a user query using PandasAI and save it.

    Args:
        df: Dataframe object exposing ``chat(query)`` (as returned by
            ``pai.DataFrame``) that contains the data
        query: User query string describing the desired visualization

    Returns:
        Path to the saved visualization file, or None when the API key is
        missing, no figure was produced, or any step raised an exception.
    """
    output_file = "visualization_output.png"
    try:
        # The key is read from the PANDAS_KEY environment variable, which may
        # be supplied through a .env file picked up by load_dotenv().
        load_dotenv()
        api_key = os.environ.get("PANDAS_KEY")
        if not api_key:
            # Report the missing key explicitly; a bare KeyError would only
            # print "'PANDAS_KEY'", which does not explain what went wrong.
            print("Error generating visualization: PANDAS_KEY is not set "
                  "(define it in the environment or in a .env file)")
            return None
        pai.api_key.set(api_key)

        # Announce the work before starting it, not after it has finished.
        print(f"Generating visualization for query: '{query}'")
        df.chat(query)

        # plt.savefig() silently creates an empty figure when none is open,
        # which would save a blank image and report success.
        if not plt.get_fignums():
            print("Error generating visualization: the query did not leave "
                  "a matplotlib figure to save")
            return None

        # Save the current figure, then release it.
        plt.savefig(output_file)
        plt.close()
        print(f"Visualization saved to {output_file}")
        return output_file
    except Exception as e:
        # Top-level boundary for this script: report the failure and signal
        # it to the caller with None instead of propagating.
        print(f"Error generating visualization: {str(e)}")
        return None
def main():
    """Run the visualization script.

    The query is taken from the command-line arguments when any are given
    (multiple words are joined with spaces, so quoting is optional);
    otherwise a built-in default query is used.
    """
    default_query = "Plot a bar chart of sales by region"
    # Fall back to the default when no argument (or only whitespace) is given.
    query = " ".join(sys.argv[1:]).strip() or default_query

    # Create sample dataframe
    df = create_sample_dataframe()
    print("Sample DataFrame created:")
    print(df.head())

    # Generate and save visualization; visualize_data returns None on failure
    output_file = visualize_data(df, query)
    if output_file:
        print(f"Visualization process completed. Output saved to: {output_file}")
    else:
        print("Visualization process failed.")
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()