#!/usr/bin/env python3
"""
Google Ads Agency Agent - Pattern Learning
Learns what works for YOUR specific account over time.
"""

import sys
import argparse
from datetime import datetime, timedelta
from typing import Dict, List, Any
import json
from collections import defaultdict

from core import (
    get_client, ACCOUNTS, run_query, DATA_DIR,
    micros_to_currency, format_currency,
)


def date_range_for_days(days: int) -> str:
    """Build the operand pair for a GAQL ``segments.date BETWEEN ...`` clause.

    The window ends yesterday (today is excluded) and covers exactly ``days``
    calendar days. GAQL's ``BETWEEN`` is inclusive on both ends, so the start
    date is ``days - 1`` days before the end date.

    Args:
        days: Number of whole days the window should cover; must be >= 1.

    Returns:
        A string of the form ``'YYYY-MM-DD' AND 'YYYY-MM-DD'`` (quotes
        included), ready to be placed after ``BETWEEN`` in a query.

    Raises:
        ValueError: If ``days`` is less than 1.
    """
    if days < 1:
        raise ValueError(f"days must be >= 1, got {days}")

    # Explicit dates are used instead of the predefined LAST_N_DAYS literals
    # so that arbitrary window lengths (e.g. 90 or 365 days) are possible.
    end_date = datetime.now().date() - timedelta(days=1)  # Yesterday
    # Inclusive range: subtracting `days - 1` yields exactly `days` dates.
    start_date = end_date - timedelta(days=days - 1)
    return f"'{start_date}' AND '{end_date}'"


def analyze_time_patterns(client, account_key: str, days: int = 90) -> Dict:
    """
    Learn which hours and days perform best.

    Queries enabled campaigns segmented by hour and day of week, sums cost,
    conversions and clicks per hour and per weekday, and ranks both by CPA
    (cost per conversion, lower is better).

    Args:
        client: API client handed through to ``run_query``.
        account_key: Key into ``ACCOUNTS``; the entry's ``'id'`` is queried.
        days: Length of the look-back window passed to
            ``date_range_for_days``.

    Returns:
        A dict with:
          * ``hour_performance``: hour -> ``cost``, ``conversions``, ``cpa``
            and ``conv_rate`` (conversions per click, in percent).
          * ``day_performance``: weekday name -> ``cost``, ``conversions``,
            ``cpa``.
          * ``best_hours``: up to 6 hours with the lowest CPA.
          * ``worst_hours``: up to 3 hours with the highest CPA, never
            overlapping ``best_hours``.
          * ``best_days``: up to 3 weekday names with the lowest CPA.
        ``cpa`` is ``float('inf')`` for buckets without conversions; such
        buckets are excluded from every ranking.
    """
    account = ACCOUNTS[account_key]
    date_range = date_range_for_days(days)

    # Hour of day analysis
    hour_query = f"""
        SELECT
            segments.hour,
            segments.day_of_week,
            metrics.cost_micros,
            metrics.conversions,
            metrics.clicks
        FROM campaign
        WHERE segments.date BETWEEN {date_range}
          AND campaign.status = 'ENABLED'
    """

    results = run_query(client, account['id'], hour_query)

    # Aggregate by hour and day
    hour_data = defaultdict(lambda: {'cost': 0, 'conversions': 0, 'clicks': 0})
    day_data = defaultdict(lambda: {'cost': 0, 'conversions': 0, 'clicks': 0})

    for row in results:
        hour = row.segments.hour
        day = row.segments.day_of_week.name

        cost = micros_to_currency(row.metrics.cost_micros)
        conv = row.metrics.conversions
        clicks = row.metrics.clicks

        # Every row contributes to both its hour bucket and its weekday bucket.
        for bucket in (hour_data[hour], day_data[day]):
            bucket['cost'] += cost
            bucket['conversions'] += conv
            bucket['clicks'] += clicks

    # Calculate CPA per hour
    hour_performance = {}
    for hour, data in hour_data.items():
        hour_performance[hour] = {
            'cost': data['cost'],
            'conversions': data['conversions'],
            'cpa': data['cost'] / data['conversions'] if data['conversions'] > 0 else float('inf'),
            'conv_rate': data['conversions'] / data['clicks'] * 100 if data['clicks'] > 0 else 0,
        }

    # Rank only hours that actually converted, cheapest CPA first.
    sorted_hours = sorted(
        [(h, p) for h, p in hour_performance.items() if p['conversions'] > 0],
        key=lambda x: x[1]['cpa']
    )

    best_count = 6
    worst_count = 3
    best_hours = [h for h, _ in sorted_hours[:best_count]]
    # Take the worst hours only from those that did not make the best list,
    # so a sparse account can never report the same hour as best and worst.
    worst_hours = [h for h, _ in sorted_hours[best_count:][-worst_count:]]

    # Day of week analysis
    day_performance = {}
    for day, data in day_data.items():
        day_performance[day] = {
            'cost': data['cost'],
            'conversions': data['conversions'],
            'cpa': data['cost'] / data['conversions'] if data['conversions'] > 0 else float('inf'),
        }

    sorted_days = sorted(
        [(d, p) for d, p in day_performance.items() if p['conversions'] > 0],
        key=lambda x: x[1]['cpa']
    )

    best_days = [d for d, _ in sorted_days[:3]]

    return {
        'hour_performance': hour_performance,
        'day_performance': day_performance,
        'best_hours': best_hours,
        'worst_hours': worst_hours,
        'best_days': best_days,
    }


def analyze_keyword_patterns(client, account_key: str, days: int = 90) -> Dict:
    """
    Learn which keyword patterns work best.

    Fetches the top 100 enabled keywords by conversions and attributes each
    keyword's conversions and cost to every distinct word it contains (words
    of 3+ characters) and to its match type.

    Args:
        client: API client handed through to ``run_query``.
        account_key: Key into ``ACCOUNTS``; the entry's ``'id'`` is queried.
        days: Length of the look-back window passed to
            ``date_range_for_days``.

    Returns:
        A dict with:
          * ``top_converting_patterns``: up to 20 ``(word, stats)`` pairs for
            words found in at least 3 keywords, ordered by conversions
            (highest first).
          * ``match_type_performance``: ``"match_<TYPE>"`` -> stats.
        Each ``stats`` dict has ``count``, ``conversions`` and ``cost``.
    """
    account = ACCOUNTS[account_key]
    date_range = date_range_for_days(days)

    query = f"""
        SELECT
            ad_group_criterion.keyword.text,
            ad_group_criterion.keyword.match_type,
            metrics.cost_micros,
            metrics.conversions,
            metrics.clicks
        FROM keyword_view
        WHERE segments.date BETWEEN {date_range}
          AND ad_group_criterion.status = 'ENABLED'
          AND metrics.conversions > 0
        ORDER BY metrics.conversions DESC
        LIMIT 100
    """

    results = run_query(client, account['id'], query)

    # Words and match types are tracked separately so that a keyword word
    # which happens to start with "match_" cannot collide with a match-type
    # entry (and is not dropped from the word ranking).
    word_stats = defaultdict(lambda: {'count': 0, 'conversions': 0, 'cost': 0})
    match_stats = defaultdict(lambda: {'count': 0, 'conversions': 0, 'cost': 0})

    for row in results:
        keyword = row.ad_group_criterion.keyword.text.lower()
        conv = row.metrics.conversions
        cost = micros_to_currency(row.metrics.cost_micros)

        # dict.fromkeys de-duplicates while keeping order: a word repeated
        # inside one keyword must not be credited with that keyword's
        # conversions and cost more than once.
        for word in dict.fromkeys(keyword.split()):
            if len(word) > 2:  # Skip short words
                stats = word_stats[word]
                stats['count'] += 1
                stats['conversions'] += conv
                stats['cost'] += cost

        # Match type analysis
        match_key = f"match_{row.ad_group_criterion.keyword.match_type.name}"
        stats = match_stats[match_key]
        stats['count'] += 1
        stats['conversions'] += conv
        stats['cost'] += cost

    # Keep only words seen in several keywords, then take the 20 with the
    # most conversions.
    top_patterns = sorted(
        [(word, stats) for word, stats in word_stats.items() if stats['count'] >= 3],
        key=lambda item: item[1]['conversions'],
        reverse=True
    )[:20]

    return {
        'top_converting_patterns': top_patterns,
        'match_type_performance': dict(match_stats),
    }


def analyze_device_patterns(client, account_key: str, days: int = 90) -> Dict:
    """
    Learn device performance patterns.

    Sums cost, conversions and clicks of enabled campaigns per device and
    compares each device's CPA with the account-wide CPA.

    Args:
        client: API client handed through to ``run_query``.
        account_key: Key into ``ACCOUNTS``; the entry's ``'id'`` is queried.
        days: Length of the look-back window passed to
            ``date_range_for_days``.

    Returns:
        Device name -> dict with ``cost``, ``conversions``, ``cpa``,
        ``cost_share`` (percent of total cost) and ``vs_average`` (percent
        difference of the device CPA from the overall CPA; negative means
        cheaper). For a device without conversions the CPA is undefined:
        ``cpa`` is reported as 0 and ``vs_average`` as 0, so the device is
        never presented as beating the average.
    """
    account = ACCOUNTS[account_key]
    date_range = date_range_for_days(days)

    query = f"""
        SELECT
            segments.device,
            metrics.cost_micros,
            metrics.conversions,
            metrics.clicks
        FROM campaign
        WHERE segments.date BETWEEN {date_range}
          AND campaign.status = 'ENABLED'
    """

    results = run_query(client, account['id'], query)

    device_data = defaultdict(lambda: {'cost': 0, 'conversions': 0, 'clicks': 0})

    for row in results:
        totals = device_data[row.segments.device.name]
        totals['cost'] += micros_to_currency(row.metrics.cost_micros)
        totals['conversions'] += row.metrics.conversions
        totals['clicks'] += row.metrics.clicks

    total_cost = sum(d['cost'] for d in device_data.values())
    total_conv = sum(d['conversions'] for d in device_data.values())
    avg_cpa = total_cost / total_conv if total_conv > 0 else 0

    device_performance = {}
    for device, data in device_data.items():
        has_conversions = data['conversions'] > 0
        cpa = data['cost'] / data['conversions'] if has_conversions else 0
        # Comparing the placeholder CPA of 0 with the average would yield
        # -100%, i.e. "best device", for a device that only burned budget.
        # Report a neutral 0 instead when there is nothing to compare.
        if has_conversions and avg_cpa > 0:
            vs_average = (cpa - avg_cpa) / avg_cpa * 100
        else:
            vs_average = 0
        device_performance[device] = {
            'cost': data['cost'],
            'conversions': data['conversions'],
            'cpa': cpa,
            'cost_share': data['cost'] / total_cost * 100 if total_cost > 0 else 0,
            'vs_average': vs_average,
        }

    return device_performance


def analyze_seasonal_patterns(client, account_key: str) -> Dict:
    """
    Summarise account-level performance month by month.

    Looks back over the window produced by ``date_range_for_days(365)`` and
    totals cost, conversions and clicks for each ``segments.month`` value.

    Args:
        client: API client handed through to ``run_query``.
        account_key: Key into ``ACCOUNTS``; the entry's ``'id'`` is queried.

    Returns:
        Month value -> dict with ``cost``, ``conversions`` and ``cpa``,
        inserted in ascending order of the month value. ``cpa`` is 0 for
        months without conversions.
    """
    customer_id = ACCOUNTS[account_key]['id']
    window = date_range_for_days(365)

    # One year of customer-level metrics, segmented by month.
    query = f"""
        SELECT
            segments.month,
            metrics.cost_micros,
            metrics.conversions,
            metrics.clicks
        FROM customer
        WHERE segments.date BETWEEN {window}
    """

    totals_by_month = defaultdict(lambda: {'cost': 0, 'conversions': 0, 'clicks': 0})

    for row in run_query(client, customer_id, query):
        # NOTE(review): month values are presumably date-like strings that
        # sort chronologically - confirm against the API response.
        bucket = totals_by_month[row.segments.month]
        bucket['cost'] += micros_to_currency(row.metrics.cost_micros)
        bucket['conversions'] += row.metrics.conversions
        bucket['clicks'] += row.metrics.clicks

    monthly_performance = {}
    for month in sorted(totals_by_month):
        totals = totals_by_month[month]
        if totals['conversions'] > 0:
            cpa = totals['cost'] / totals['conversions']
        else:
            cpa = 0
        monthly_performance[month] = {
            'cost': totals['cost'],
            'conversions': totals['conversions'],
            'cpa': cpa,
        }

    return monthly_performance


def generate_learnings(client, account_key: str) -> Dict:
    """Run every pattern analysis for one account and bundle the results.

    The analyses run in the order time, keyword, device, seasonal; the
    ``analyzed_at`` timestamp (ISO format, local time) is taken afterwards.
    """
    analyses = (
        ('time_patterns', analyze_time_patterns),
        ('keyword_patterns', analyze_keyword_patterns),
        ('device_patterns', analyze_device_patterns),
        ('seasonal_patterns', analyze_seasonal_patterns),
    )

    learnings = {}
    for section, analyze in analyses:
        learnings[section] = analyze(client, account_key)

    learnings['analyzed_at'] = datetime.now().isoformat()
    return learnings


def save_patterns(account_key: str, patterns: Dict):
    """Save learned patterns.

    Merges ``patterns`` into ``DATA_DIR / "patterns.json"`` under
    ``account_key``, replacing any earlier entry for that account and keeping
    the entries of all other accounts.

    The new content is written to a sibling ``.tmp`` file first and then
    moved over the real file, so an interrupted run cannot leave a truncated
    ``patterns.json`` that would lose every account's data.

    Args:
        account_key: Account identifier used as the top-level JSON key.
        patterns: Learnings for that account. Values the JSON encoder cannot
            serialise are converted with ``str`` (``default=str``).
    """
    patterns_file = DATA_DIR / "patterns.json"
    # First run: the data directory may not exist yet.
    patterns_file.parent.mkdir(parents=True, exist_ok=True)

    all_patterns = {}
    if patterns_file.exists():
        with open(patterns_file, encoding="utf-8") as f:
            all_patterns = json.load(f)

    all_patterns[account_key] = patterns

    tmp_file = patterns_file.with_name(patterns_file.name + ".tmp")
    with open(tmp_file, 'w', encoding="utf-8") as f:
        json.dump(all_patterns, f, indent=2, default=str)
    # Swap the fully written file into place in a single step.
    tmp_file.replace(patterns_file)


def format_insights(account_name: str, patterns: Dict) -> str:
    """Format pattern insights for display.

    Renders a Markdown-style summary with one section per kind of learning.
    Sections whose data is missing or empty are left out.

    Args:
        account_name: Human-readable account name shown in the heading.
        patterns: Learnings dict with the optional keys ``time_patterns``,
            ``device_patterns`` and ``keyword_patterns``.

    Returns:
        The formatted text. Devices that recorded zero conversions are
        listed as "no conversions" rather than with a CPA comparison, since
        their placeholder CPA of 0 would otherwise look better than average.
    """
    parts = [f"🧠 **Account Insights: {account_name}**\n\n"]

    # Time patterns
    time_patterns = patterns.get('time_patterns', {})
    if time_patterns.get('best_hours'):
        hours = ', '.join(f'{h}:00' for h in time_patterns['best_hours'])
        parts.append("**Best Hours (by CPA):**\n")
        parts.append(f"• {hours}\n\n")

    if time_patterns.get('best_days'):
        parts.append("**Best Days:**\n")
        parts.append(f"• {', '.join(time_patterns['best_days'])}\n\n")

    # Device patterns
    devices = patterns.get('device_patterns', {})
    if devices:
        parts.append("**Device Performance:**\n")
        for device, data in devices.items():
            conversions = data.get('conversions')
            if conversions is not None and conversions <= 0:
                # No conversions means there is no CPA to compare; flag the
                # device if it still consumed budget.
                emoji = "⚠️" if data.get('cost', 0) > 0 else "➖"
                share = data.get('cost_share', 0)
                parts.append(f"• {device}: no conversions {emoji} ({share:.0f}% of spend)\n")
                continue
            vs_avg = data.get('vs_average', 0)
            # More than 10% cheaper than average is good, more than 10%
            # pricier is a warning, anything in between is neutral.
            emoji = "✅" if vs_avg < -10 else "⚠️" if vs_avg > 10 else "➖"
            parts.append(f"• {device}: {format_currency(data['cpa'])} CPA {emoji} ({vs_avg:+.0f}% vs avg)\n")
        parts.append("\n")

    # Keyword patterns
    kw = patterns.get('keyword_patterns', {})
    top_patterns = kw.get('top_converting_patterns', [])
    if top_patterns:
        parts.append("**Top Converting Keywords Contain:**\n")
        for word, data in top_patterns[:5]:
            parts.append(f"• '{word}': {data['conversions']:.0f} conversions\n")
        parts.append("\n")

    # Match types
    match_types = kw.get('match_type_performance', {})
    if match_types:
        parts.append("**Match Type Performance:**\n")
        for mt, data in match_types.items():
            name = mt.replace('match_', '')
            cpa = data['cost'] / data['conversions'] if data['conversions'] > 0 else 0
            parts.append(f"• {name}: {format_currency(cpa)} CPA ({data['conversions']:.0f} conv)\n")

    return "".join(parts)


def main():
    """Command-line entry point: learn, save and report patterns per account.

    Options:
        --account  Case-insensitive substring; only account keys containing
                   it are analysed. Without it every account is analysed.
        --json     Print all results as one JSON array on stdout instead of
                   the formatted text report.

    Progress and error messages go to stderr so stdout stays clean for the
    report. A failure in one account is reported and does not stop the
    remaining accounts.
    """
    parser = argparse.ArgumentParser(description="Learn Google Ads patterns")
    parser.add_argument("--account", type=str, help="Specific account to analyze")
    parser.add_argument("--json", action="store_true", help="Output as JSON")
    args = parser.parse_args()

    client = get_client()

    accounts_to_analyze = list(ACCOUNTS)
    if args.account:
        needle = args.account.lower()
        accounts_to_analyze = [k for k in accounts_to_analyze if needle in k.lower()]
        if not accounts_to_analyze:
            # Without this notice a mistyped filter looks like a successful
            # run that simply produced no output.
            available = ', '.join(str(k) for k in ACCOUNTS)
            print(
                f"No account matches '{args.account}'. Available accounts: {available}",
                file=sys.stderr,
            )

    all_insights = []

    for account_key in accounts_to_analyze:
        account = ACCOUNTS[account_key]
        print(f"Analyzing patterns for {account['name']}...", file=sys.stderr)

        try:
            patterns = generate_learnings(client, account_key)
            save_patterns(account_key, patterns)

            if args.json:
                all_insights.append({
                    'account': account_key,
                    'patterns': patterns,
                })
            else:
                print(format_insights(account['name'], patterns))
                print()
        except Exception as e:
            # Top-level boundary: report the failure and continue with the
            # next account instead of aborting the whole run.
            print(f"Error analyzing {account['name']}: {e}", file=sys.stderr)

    if args.json:
        print(json.dumps(all_insights, indent=2, default=str))

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
