Training a Regression Model on 170k Songs for Real-Time DJ Mix Recommendations

Tech Stack

AWS SageMaker
Python
ML
AWS Lambda
Spotify API
React
MongoDB

Built a regression model on AWS SageMaker that analyzes BPM, energy, and valence across 170k tracks to predict 'mixing scores' between songs. Deployed real-time inference via Lambda + API Gateway, letting DJs find compatible transitions in under 500 ms. Integrated the Spotify API for genre labels (looked up through artist metadata).

Article

Project Background

I started this project out of a newfound passion for DJing. The deeper I got, the more I saw how many variables DJs juggle at once: BPM, valence, energy, and more. Keeping all of them in balance is genuinely hard, and I also noticed a lack of platforms built specifically for DJs.

My vision was a comprehensive site built exclusively for DJs: not just a place to share remixes, but also a social network that amplifies their reach. On top of that, I wanted an intelligent recommender system and a set of tools and resources to help DJs sharpen their skills.

The Technical Challenge

The core problem: How do you predict which songs will mix well together based on their musical features?

This required:

Architecture Overview

┌─────────────┐
│   React UI  │
└──────┬──────┘
       │
       ▼
┌─────────────┐     ┌──────────────┐
│  Spotify    │────▶│  Feature     │
│  API        │     │  Extraction  │
└─────────────┘     └──────┬───────┘
                           │
                           ▼
                    ┌──────────────┐
                    │  AWS Lambda  │
                    │  (Inference) │
                    └──────┬───────┘
                           │
                           ▼
┌─────────────┐     ┌──────────────┐
│  MongoDB    │◀────│  Sagemaker   │
│  (170k)     │     │  Model       │
└─────────────┘     └──────────────┘

Part 1: Data Collection & Processing

Dataset Acquisition

I sourced a dataset of 170,000 songs, each enriched with 8 audio features: tempo (BPM), energy, valence, danceability, acousticness, instrumentalness, liveness, and speechiness.

Feature Engineering Script

import pandas as pd
import spotipy
from sklearn.preprocessing import StandardScaler
from spotipy.oauth2 import SpotifyClientCredentials
 
# Initialize Spotify API (client-credentials flow, no user login needed)
sp = spotipy.Spotify(auth_manager=SpotifyClientCredentials(
    client_id='YOUR_CLIENT_ID',
    client_secret='YOUR_CLIENT_SECRET'
))
 
def extract_audio_features(track_id):
    """Extract audio features for a single track (None if Spotify has no analysis)"""
    features = sp.audio_features(track_id)[0]
    if features is None:
        return None
    
    return {
        'bpm': features['tempo'],
        'energy': features['energy'],
        'valence': features['valence'],
        'danceability': features['danceability'],
        'acousticness': features['acousticness'],
        'instrumentalness': features['instrumentalness'],
        'liveness': features['liveness'],
        'speechiness': features['speechiness']
    }
 
def get_genre(track_id):
    """Retrieve genre for a track via artist lookup (first genre of the first artist)"""
    track = sp.track(track_id)
    artist_id = track['artists'][0]['id']
    artist = sp.artist(artist_id)
    return artist['genres'][0] if artist['genres'] else 'unknown'
 
# Process entire dataset
df = pd.read_csv('raw_tracks.csv')
features = df['track_id'].apply(extract_audio_features)
 
# Expand each feature dict into its own column; drop tracks with no features
df = df[features.notna()].copy()
df = df.join(pd.DataFrame(features.dropna().tolist(), index=df.index))
df['genre'] = df['track_id'].apply(get_genre)
 
# Normalize features into separate *_scaled columns. The raw values stay as-is
# because the mixing-score heuristic works in real units (BPM, 0-1 scales);
# the random forest itself does not need scaled inputs.
scaled_cols = ['bpm', 'energy', 'valence', 'danceability']
scaler = StandardScaler()
scaled = scaler.fit_transform(df[scaled_cols])
for i, col in enumerate(scaled_cols):
    df[f'{col}_scaled'] = scaled[:, i]
 
df.to_csv('processed_170k_tracks.csv', index=False)
print(f"Processed {len(df)} tracks")
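The script above makes one `audio_features` request per track (plus two more for the genre), which adds up to hundreds of thousands of calls for 170k tracks. spotipy's `audio_features` accepts a list of up to 100 IDs, so batching brings the feature requests down to about 1,700. A minimal sketch; `chunked` and `fetch_audio_features` are hypothetical helper names, and `sp` is assumed to be the spotipy client created above:

```python
from itertools import islice

# Spotify field name -> column name used in this project ('tempo' is stored as 'bpm')
FEATURE_FIELDS = {'tempo': 'bpm', 'energy': 'energy', 'valence': 'valence',
                  'danceability': 'danceability', 'acousticness': 'acousticness',
                  'instrumentalness': 'instrumentalness', 'liveness': 'liveness',
                  'speechiness': 'speechiness'}

def chunked(items, size):
    """Yield successive lists of at most `size` items."""
    it = iter(items)
    while chunk := list(islice(it, size)):
        yield chunk

def fetch_audio_features(sp, track_ids, batch_size=100):
    """Return {track_id: features}, using one audio_features call per batch of IDs."""
    features = {}
    for batch in chunked(track_ids, batch_size):
        # spotipy returns one entry per requested ID, in order; None means no analysis
        for track_id, f in zip(batch, sp.audio_features(batch)):
            if f is not None:
                features[track_id] = {ours: f[theirs] for theirs, ours in FEATURE_FIELDS.items()}
    return features
```

`pd.DataFrame.from_dict(fetch_audio_features(sp, df['track_id']), orient='index')` then yields one row per track ID. Genre lookups can be batched the same way with `sp.artists`, which accepts up to 50 artist IDs per call.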

Part 2: Model Training with AWS SageMaker

August 8th, 11:03 PM - Initial Setup

I set up MixMeister's training pipeline on SageMaker. I started with regression for the mixability score, keeping classification based on Spotify user preferences as a possible later step.

Training Script

import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.sklearn.estimator import SKLearn
 
# Initialize Sagemaker session
sagemaker_session = sagemaker.Session()
role = get_execution_role()
bucket = 'djdoc-ml-models'
 
# Upload training data to S3
train_data = sagemaker_session.upload_data(
    path='processed_170k_tracks.csv',
    bucket=bucket,
    key_prefix='training'
)
 
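# Configure the scikit-learn training job (script mode)
sklearn_estimator = SKLearn(
    entry_point='train_mixing_model.py',
    role=role,
    instance_count=1,
    instance_type='ml.m5.xlarge',
    framework_version='0.23-1',
    py_version='py3',
    # Passed to the entry point as CLI arguments (--n_estimators 200, ...)
    hyperparameters={
        'estimator': 'RandomForest',
        'n_estimators': 200,
        'max_depth': 20,
        'min_samples_split': 5
    }
)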
 
# Start training job
sklearn_estimator.fit({'train': train_data})

The Mixing Score Algorithm

The core of the model: predicting how well two songs mix based on feature similarity.

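# train_mixing_model.py (SageMaker script-mode entry point)
import argparse
import os
 
import joblib
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
 
FEATURES = ['bpm', 'energy', 'valence', 'danceability']
 
def calculate_mixing_score(song1_features, song2_features):
    """
    Calculate mixing compatibility score between two songs.
    Score ranges from 0 (poor mix) to 100 (perfect mix).
    Expects raw units (BPM; energy/valence on Spotify's 0-1 scale) and works on
    scalars or on whole pandas columns / NumPy arrays at once.
    """
    # BPM difference (most critical): lose 2 points per BPM of difference
    bpm_diff = np.abs(song1_features['bpm'] - song2_features['bpm'])
    bpm_score = np.maximum(0, 100 - bpm_diff * 2)
    
    # Energy compatibility
    energy_diff = np.abs(song1_features['energy'] - song2_features['energy'])
    energy_score = np.maximum(0, 100 - energy_diff * 100)
    
    # Valence (mood) compatibility
    valence_diff = np.abs(song1_features['valence'] - song2_features['valence'])
    valence_score = np.maximum(0, 100 - valence_diff * 80)
    
    # Weighted combination
    return bpm_score * 0.4 + energy_score * 0.3 + valence_score * 0.3
 
def generate_training_pairs(df, window=100):
    """Pair each song with up to the next window-1 rows, one vectorized offset at a time"""
    feats = df[FEATURES].reset_index(drop=True)
    X_parts, y_parts = [], []
    for offset in range(1, window):
        a = feats.iloc[:-offset].reset_index(drop=True)
        b = feats.iloc[offset:].reset_index(drop=True)
        if a.empty:
            break
        X_parts.append(np.hstack([a.values, b.values]))  # [song1 feats, song2 feats]
        y_parts.append(calculate_mixing_score(a, b).values)
    return np.vstack(X_parts), np.concatenate(y_parts)
 
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Hyperparameters from the SKLearn estimator arrive as CLI arguments
    parser.add_argument('--estimator', type=str, default='RandomForest')
    parser.add_argument('--n_estimators', type=int, default=200)
    parser.add_argument('--max_depth', type=int, default=20)
    parser.add_argument('--min_samples_split', type=int, default=5)
    # SageMaker exposes the 'train' channel and the model output dir as env vars
    parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN', '.'))
    parser.add_argument('--model-dir', type=str, default=os.environ.get('SM_MODEL_DIR', '/opt/ml/model'))
    args, _ = parser.parse_known_args()
 
    # Train model
    df = pd.read_csv(os.path.join(args.train, 'processed_170k_tracks.csv'))
    X, y = generate_training_pairs(df)
 
    # Hold out 20% of pairs: scoring on the training data overstates accuracy
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
 
    model = RandomForestRegressor(
        n_estimators=args.n_estimators,
        max_depth=args.max_depth,
        min_samples_split=args.min_samples_split,
        random_state=42
    )
    model.fit(X_train, y_train)
 
    # Evaluate on the held-out pairs
    predictions = model.predict(X_test)
    print(f"Test R² Score: {r2_score(y_test, predictions):.3f}")
    print(f"Test RMSE: {np.sqrt(mean_squared_error(y_test, predictions)):.3f}")
 
    # Save model
    joblib.dump(model, os.path.join(args.model_dir, 'mixing_model.pkl'))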
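To make the 40/30/30 weighting concrete, here is the heuristic applied to two hypothetical house tracks (the feature values are invented for illustration): 124 vs. 128 BPM, with energy and valence each 0.1 apart. The function restates the same formula so the snippet runs on its own.

```python
import numpy as np

def calculate_mixing_score(song1, song2):
    # Same formula as above: BPM 40%, energy 30%, valence 30%
    bpm_score = np.maximum(0, 100 - np.abs(song1['bpm'] - song2['bpm']) * 2)
    energy_score = np.maximum(0, 100 - np.abs(song1['energy'] - song2['energy']) * 100)
    valence_score = np.maximum(0, 100 - np.abs(song1['valence'] - song2['valence']) * 80)
    return bpm_score * 0.4 + energy_score * 0.3 + valence_score * 0.3

track_a = {'bpm': 124.0, 'energy': 0.80, 'valence': 0.60}
track_b = {'bpm': 128.0, 'energy': 0.70, 'valence': 0.50}

# BPM: 100 - 4*2 = 92 -> 36.8 | energy: 100 - 10 = 90 -> 27.0 | valence: 100 - 8 = 92 -> 27.6
score = calculate_mixing_score(track_a, track_b)
print(round(float(score), 1))  # → 91.4
```

Because the BPM term drops 2 points per BPM, it reaches zero at a 50 BPM gap; past that point, only energy and valence contribute (at most 60 points).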

August 10th, 12:40 AM - Algorithm Optimization

Switched from random forest classification to random forest regression. Regression produces continuous mixing scores instead of discrete categories, which gives DJs more nuanced recommendations.

Part 3: Deployment Pipeline

August 14th, 1:48 AM - Model Endpoint Deployment

MixMeister went live! The goal: generate mixing scores for a searched song and suggest 3–5 compatible tracks.
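Serving the model from a SageMaker scikit-learn endpoint also requires the entry-point script to tell the container how to load the model and how to parse the `{"instances": [...]}` JSON the Lambda sends. A minimal sketch, assuming the scikit-learn serving container's `model_fn` / `input_fn` / `predict_fn` / `output_fn` hook convention and the `mixing_model.pkl` filename used in the training script; the functions could sit in `train_mixing_model.py` below the `__main__` guard's imports:

```python
import json
import os

import numpy as np

def model_fn(model_dir):
    # Load the RandomForestRegressor written by the training job
    import joblib  # ships with scikit-learn in the serving container
    return joblib.load(os.path.join(model_dir, 'mixing_model.pkl'))

def input_fn(request_body, content_type):
    # Accept {"instances": [[8 features], ...]} and return a 2-D float array
    if content_type != 'application/json':
        raise ValueError(f'Unsupported content type: {content_type}')
    return np.asarray(json.loads(request_body)['instances'], dtype=float)

def predict_fn(input_data, model):
    # One mixing score per feature pair
    return model.predict(input_data)

def output_fn(prediction, accept):
    # Wrap scores as {"predictions": [...]}, the shape the Lambda reads back
    return json.dumps({'predictions': np.asarray(prediction).tolist()})
```

The `__main__` guard in the training script matters here: the serving container imports the entry point, and without the guard that import would retrain the model.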

Lambda Function for Real-Time Inference

# lambda_handler.py
import json
import os
 
import boto3
from pymongo import MongoClient
 
# Same per-song feature order the model was trained on (song 1, then song 2)
FEATURES = ['bpm', 'energy', 'valence', 'danceability']
 
# Create clients once per container so warm invocations reuse connections
runtime = boto3.client('sagemaker-runtime')
endpoint_name = 'djdoc-mixing-model'
db = MongoClient(os.environ['MONGO_URI'])['djdoc']
 
CORS_HEADERS = {
    'Content-Type': 'application/json',
    'Access-Control-Allow-Origin': '*'
}
 
def respond(status_code, payload):
    """Build an API Gateway response; CORS headers go on errors too"""
    return {
        'statusCode': status_code,
        'headers': CORS_HEADERS,
        'body': json.dumps(payload)
    }
 
def lambda_handler(event, context):
    """
    Handle mixing score prediction requests
    Input: Current song features + desired genre
    Output: Top 5 compatible songs
    """
    try:
        # Parse request: song_features is an object like {"bpm": 124, "energy": 0.8, ...}
        body = json.loads(event['body'])
        current = [float(body['song_features'][f]) for f in FEATURES]
        target_genre = body.get('genre', 'all')
        
        # Query MongoDB for candidate songs
        candidates = get_candidate_songs(target_genre)
        if not candidates:
            return respond(200, [])
        
        # One row per candidate: current song's features, then the candidate's
        # (plain lists, because NumPy arrays are not JSON-serializable)
        feature_pairs = [
            current + [float(c['features'][f]) for f in FEATURES]
            for c in candidates
        ]
        
        # Invoke SageMaker endpoint
        response = runtime.invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType='application/json',
            Body=json.dumps({'instances': feature_pairs})
        )
        predictions = json.loads(response['Body'].read())['predictions']
        
        # Rank and return top 5 in the shape the React client renders
        ranked = sorted(zip(candidates, predictions), key=lambda x: x[1], reverse=True)
        top_5 = [
            {
                'id': song['track_id'],
                'name': song['name'],
                'artist': song['artist'],
                'bpm': song['features']['bpm'],
                'mixingScore': float(score)
            }
            for song, score in ranked[:5]
        ]
        return respond(200, top_5)
        
    except Exception as e:
        print(f"Error: {str(e)}")
        return respond(500, {'error': str(e)})
 
def get_candidate_songs(genre):
    """Query MongoDB for up to 100 songs in the target genre"""
    query = {} if genre == 'all' else {'genre': genre}
    # Skip _id: ObjectId values are not JSON-serializable
    projection = {'_id': 0, 'track_id': 1, 'name': 1, 'artist': 1, 'features': 1}
    return list(db.tracks.find(query, projection).limit(100))
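`get_candidate_songs` takes whichever 100 tracks MongoDB returns first, so many candidates can sit far outside a mixable tempo and never score well on BPM. One option is to filter on tempo in the query itself. A sketch with a hypothetical `build_candidate_query` helper and an assumed ±6 BPM window (within which the BPM component of the heuristic stays at 88 or above):

```python
def build_candidate_query(genre, bpm, window=6.0):
    """MongoDB filter for tracks within ±window BPM, optionally restricted to one genre."""
    query = {'features.bpm': {'$gte': bpm - window, '$lte': bpm + window}}
    if genre != 'all':
        query['genre'] = genre
    return query
```

In the Lambda this would replace the genre-only filter, e.g. `db.tracks.find(build_candidate_query(target_genre, current[0]), projection).limit(100)`. When a genre is given, the `{ genre: 1, "features.bpm": 1 }` compound index in the schema section can serve the query; with `'all'` the filter is BPM-only, which that index cannot use efficiently because its leading key is `genre`.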

August 23rd, 2:18 PM - CORS Fix & MongoDB Integration

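Going from training to deployment through SageMaker endpoints, Lambda functions, and API Gateway went smoothly, but a CORS issue needed fixing: before the React app's cross-origin JSON POST, the browser sends a preflight OPTIONS request, and the API has to answer it with the right Access-Control headers.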

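// Lambda handler behind API Gateway that answers CORS preflight requests
const CORS_HEADERS = {
  'Access-Control-Allow-Origin': '*',
  'Access-Control-Allow-Methods': 'GET, POST, OPTIONS',
  'Access-Control-Allow-Headers': 'Content-Type, Authorization',
};
 
module.exports.handler = async (event) => {
  // HTTP APIs (payload v2) put the method in requestContext.http; REST APIs use httpMethod
  const method = event.requestContext?.http?.method ?? event.httpMethod;
 
  // Handle preflight OPTIONS request
  if (method === 'OPTIONS') {
    return { statusCode: 200, headers: CORS_HEADERS, body: '' };
  }
  
  // ... rest of handler (its responses should include CORS_HEADERS too)
};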
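Since the inference Lambda is written in Python, the same preflight handling could also live at the top of `lambda_handler` rather than in a separate Node function. A minimal sketch; `request_method` and `preflight_response` are hypothetical helper names, and the header values mirror the Node version:

```python
CORS_HEADERS = {
    'Access-Control-Allow-Origin': '*',
    'Access-Control-Allow-Methods': 'GET, POST, OPTIONS',
    'Access-Control-Allow-Headers': 'Content-Type, Authorization',
}

def request_method(event):
    # HTTP API (payload v2) events carry requestContext.http.method;
    # REST API events carry a top-level httpMethod
    http = event.get('requestContext', {}).get('http', {})
    return http.get('method') or event.get('httpMethod')

def preflight_response(event):
    """Return an empty 200 with CORS headers for OPTIONS requests, else None."""
    if request_method(event) == 'OPTIONS':
        return {'statusCode': 200, 'headers': CORS_HEADERS, 'body': ''}
    return None
```

Inside `lambda_handler`, `early = preflight_response(event)` followed by `if early: return early` short-circuits before the body is parsed. HTTP APIs can alternatively be given a CORS configuration, in which case API Gateway answers preflight requests itself.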

MongoDB Schema Design

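// MongoDB track collection (mongosh): validate the document shape, then index.
// _id (ObjectId) is added automatically on insert.
db.createCollection("tracks", {
  validator: {
    $jsonSchema: {
      bsonType: "object",
      properties: {
        track_id: { bsonType: "string", description: "Spotify ID" },
        name: { bsonType: "string" },
        artist: { bsonType: "string" },
        album: { bsonType: "string" },
        genre: { bsonType: "string" },
        features: {
          bsonType: "object",
          properties: {
            bpm: { bsonType: "number" },
            energy: { bsonType: "number" },
            valence: { bsonType: "number" },
            danceability: { bsonType: "number" },
            acousticness: { bsonType: "number" },
            instrumentalness: { bsonType: "number" },
            liveness: { bsonType: "number" },
            speechiness: { bsonType: "number" }
          }
        },
        mixing_scores: {  // Pre-computed scores for fast lookup
          bsonType: "array",
          items: {
            bsonType: "object",
            properties: {
              with_track_id: { bsonType: "string" },
              score: { bsonType: "number" }
            }
          }
        },
        created_at: { bsonType: "date" }
      }
    }
  }
});
 
// Indexes for fast queries
db.tracks.createIndex({ genre: 1, "features.bpm": 1 });
db.tracks.createIndex({ "features.energy": 1, "features.valence": 1 });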
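Loading the processed CSV into this collection means reshaping each flat row into the nested document above. A sketch, assuming the CSV carries `track_id`, `name`, `artist`, `album`, and `genre` columns next to the eight raw feature columns (the source does not show the raw CSV's column names, so these are assumptions); `row_to_track_doc` is a hypothetical helper:

```python
from datetime import datetime, timezone

FEATURE_KEYS = ['bpm', 'energy', 'valence', 'danceability', 'acousticness',
                'instrumentalness', 'liveness', 'speechiness']

def row_to_track_doc(row):
    """Map one processed-CSV row (a dict) onto the nested track document shape."""
    return {
        'track_id': row['track_id'],
        'name': row['name'],
        'artist': row['artist'],
        'album': row.get('album', ''),
        'genre': row.get('genre', 'unknown'),
        # Raw (unscaled) values, so the Lambda and the heuristic see real units
        'features': {k: float(row[k]) for k in FEATURE_KEYS},
        'mixing_scores': [],  # left empty here; filled by a separate pre-computation step
        'created_at': datetime.now(timezone.utc),
    }
```

With pymongo, `db.tracks.insert_many([row_to_track_doc(r) for r in df.to_dict('records')])` would then bulk-load the DataFrame.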

Part 4: Frontend Implementation

React Component for Song Search

// MixMeister.tsx
import React, { useState } from 'react';
import axios from 'axios';
 
interface AudioFeatures {
  bpm: number;
  energy: number;
  valence: number;
  danceability: number;
}
 
// A track returned by the search endpoint
interface Song {
  id: string;
  name: string;
  artist: string;
  bpm: number;
  features: AudioFeatures;
}
 
// A recommendation returned by the mixing-score endpoint
interface Recommendation {
  id: string;
  name: string;
  artist: string;
  bpm: number;
  mixingScore: number;
}
 
export const MixMeister: React.FC = () => {
  const [query, setQuery] = useState('');
  const [currentSong, setCurrentSong] = useState<Song | null>(null);
  const [genre, setGenre] = useState<string>('house');
  const [recommendations, setRecommendations] = useState<Recommendation[]>([]);
  const [loading, setLoading] = useState(false);
 
  // Search on submit instead of sending a request on every keystroke
  const searchSong = async (event: React.FormEvent<HTMLFormElement>) => {
    event.preventDefault();
    if (!query.trim()) return;
 
    try {
      const response = await axios.get('/api/search', {
        params: { q: query }
      });
      setCurrentSong(response.data.tracks[0] ?? null);
      setRecommendations([]);
    } catch (error) {
      console.error('Error searching for song:', error);
    }
  };
 
  const getMixingRecommendations = async () => {
    if (!currentSong) return;
    
    setLoading(true);
    try {
      const response = await axios.post<Recommendation[]>(
        'https://api.djdoc.com/mixing-score',
        {
          song_features: currentSong.features,
          genre: genre
        }
      );
      
      setRecommendations(response.data);
    } catch (error) {
      console.error('Error getting recommendations:', error);
    } finally {
      setLoading(false);
    }
  };
 
  return (
    <div className="mixmeister-container">
      <h1>MixMeister</h1>
      
      {/* Song Search */}
      <form onSubmit={searchSong}>
        <input
          type="text"
          placeholder="Search for your current song..."
          value={query}
          onChange={(e) => setQuery(e.target.value)}
        />
        <button type="submit">Search</button>
      </form>
      
      {/* Genre Selection */}
      <select value={genre} onChange={(e) => setGenre(e.target.value)}>
        <option value="house">House</option>
        <option value="techno">Techno</option>
        <option value="hip-hop">Hip Hop</option>
        <option value="all">All Genres</option>
      </select>
      
      {/* Current Song Display */}
      {currentSong && (
        <div className="current-song">
          <h3>{currentSong.name}</h3>
          <p>{currentSong.artist}</p>
          <p>BPM: {currentSong.bpm}</p>
          <button onClick={getMixingRecommendations}>
            Find Compatible Tracks
          </button>
        </div>
      )}
      
      {/* Recommendations */}
      {loading ? (
        <p>Finding the perfect mix...</p>
      ) : (
        <div className="recommendations">
          {recommendations.map((song) => (
            <div key={song.id} className="recommendation-card">
              <h4>{song.name}</h4>
              <p>{song.artist}</p>
              <div className="mixing-score">
                Mixing Score: {song.mixingScore.toFixed(1)}/100
              </div>
            </div>
          ))}
        </div>
      )}
    </div>
  );
};

Part 5: Firebase Authentication

I integrated Firebase Authentication so that specific pages are only available after a user signs up and logs in, which tightens access to the platform.

// auth.ts
import { initializeApp } from 'firebase/app';
import { getAuth, signInWithEmailAndPassword, createUserWithEmailAndPassword } from 'firebase/auth';
 
const firebaseConfig = {
  apiKey: process.env.NEXT_PUBLIC_FIREBASE_API_KEY,
  authDomain: "djdoc-platform.firebaseapp.com",
  projectId: "djdoc-platform",
};
 
const app = initializeApp(firebaseConfig);
export const auth = getAuth(app);
 
export const signUp = async (email: string, password: string) => {
  try {
    const userCredential = await createUserWithEmailAndPassword(auth, email, password);
    return userCredential.user;
  } catch (error) {
    console.error('Sign up error:', error);
    throw error;
  }
};
 
export const signIn = async (email: string, password: string) => {
  try {
    const userCredential = await signInWithEmailAndPassword(auth, email, password);
    return userCredential.user;
  } catch (error) {
    console.error('Sign in error:', error);
    throw error;
  }
};
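Gating pages in the React app does not stop someone from calling the API directly, so the Python Lambda could also check the caller. A sketch, assuming the client sends its Firebase ID token (from `await auth.currentUser.getIdToken()`) as `Authorization: Bearer <token>`, and that the Firebase Admin SDK's `auth.verify_id_token` does the real verification in production; `bearer_token` and `require_user` are hypothetical helpers, and the `verify` parameter exists only so the logic can be exercised without Firebase:

```python
def bearer_token(event):
    """Pull the token out of an 'Authorization: Bearer <token>' header (any header case)."""
    headers = {k.lower(): v for k, v in (event.get('headers') or {}).items()}
    scheme, _, token = headers.get('authorization', '').partition(' ')
    token = token.strip()
    return token if scheme.lower() == 'bearer' and token else None

def require_user(event, verify=None):
    """Return the caller's Firebase uid, or None if the token is missing or rejected."""
    token = bearer_token(event)
    if token is None:
        return None
    if verify is None:
        # Real check via the Firebase Admin SDK (firebase_admin.initialize_app() must run first)
        from firebase_admin import auth
        verify = auth.verify_id_token
    try:
        return verify(token)['uid']
    except Exception:
        # Expired, revoked, or malformed tokens all count as "not signed in"
        return None
```

A handler would return a 401 with CORS headers whenever `require_user(event)` comes back `None`.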

Part 6: Spotify Converter Feature

Built a download pipeline with the spotDL library to convert Spotify playlists into local audio files.

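# spotify_converter.py
import os
import zipfile
 
import spotipy
from flask import Flask, request, send_file
from spotipy.oauth2 import SpotifyClientCredentials
from spotdl import Spotdl
 
app = Flask(__name__)
 
CLIENT_ID = os.getenv('SPOTIFY_CLIENT_ID')
CLIENT_SECRET = os.getenv('SPOTIFY_CLIENT_SECRET')
 
# Client-credentials auth is enough to read public playlists
sp = spotipy.Spotify(auth_manager=SpotifyClientCredentials(
    client_id=CLIENT_ID, client_secret=CLIENT_SECRET
))
 
# spotDL sets up a global Spotify client, so create one instance per process
spotdl = Spotdl(client_id=CLIENT_ID, client_secret=CLIENT_SECRET)
 
def get_playlist_track_urls(playlist_id):
    """Collect Spotify URLs for every track, following pagination (100 items per page)"""
    urls = []
    results = sp.playlist_tracks(playlist_id)
    while results:
        for item in results['items']:
            track = item.get('track')
            # Skip removed tracks and local files, which have no Spotify URL
            if track and track.get('external_urls', {}).get('spotify'):
                urls.append(track['external_urls']['spotify'])
        results = sp.next(results) if results['next'] else None
    return urls
 
@app.route('/api/convert-playlist', methods=['POST'])
def convert_playlist():
    """
    Convert Spotify playlist to downloadable zip
    """
    playlist_id = request.json['playlist_id']
    
    # Resolve URLs to spotDL songs, then download them (to spotDL's output location)
    songs = spotdl.search(get_playlist_track_urls(playlist_id))
    results = spotdl.download_songs(songs)
    
    # Zip successful downloads; failed ones come back with a path of None
    zip_path = f'/tmp/playlist_{playlist_id}.zip'
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for _song, path in results:
            if path is not None:
                zipf.write(path, arcname=os.path.basename(path))
    
    return send_file(zip_path, as_attachment=True)
 
if __name__ == '__main__':
    app.run(debug=True)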

Part 7: Performance & Cost Optimization

September 8th, 10:30 PM - Project Completion

The Regen button works as intended. However, I shut down the model endpoint after it ran up $400 in AWS charges.

Lessons Learned:

  1. Always use endpoint autoscaling limits
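# SageMaker endpoint config plus an instance-count cap (boto3)
import boto3
 
boto3.client('sagemaker').create_endpoint_config(
    EndpointConfigName='djdoc-mixing-model-config',  # placeholder name
    ProductionVariants=[{
        'VariantName': 'AllTraffic',
        'ModelName': 'djdoc-mixing-model',
        'InitialInstanceCount': 1,
        'InstanceType': 'ml.t2.medium'  # Cheaper instance
    }]
)
 
# Once the endpoint is in service, cap it through Application Auto Scaling.
# A real-time variant can't go below 1 instance; Serverless Inference is the scale-to-zero option.
boto3.client('application-autoscaling').register_scalable_target(
    ServiceNamespace='sagemaker',
    ResourceId='endpoint/djdoc-mixing-model/variant/AllTraffic',
    ScalableDimension='sagemaker:variant:DesiredInstanceCount',
    MinCapacity=1,
    MaxCapacity=2
)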
  2. Implement caching to reduce API calls
// Redis caching layer (ioredis)
import Redis from 'ioredis';
 
const redis = new Redis(process.env.REDIS_URL ?? 'redis://localhost:6379');
const CACHE_TTL_SECONDS = 3600; // 1 hour
 
// Calls the Lambda through API Gateway; defined elsewhere in the app
declare function getRecommendations(songId: string, genre: string): Promise<unknown>;
 
export async function getCachedRecommendations(songId: string, genre: string) {
  const cacheKey = `recs:${songId}:${genre}`;
  const cached = await redis.get(cacheKey);
  
  if (cached) {
    return JSON.parse(cached);
  }
  
  // If not cached, call Lambda
  const recommendations = await getRecommendations(songId, genre);
  
  // Cache for 1 hour
  await redis.setex(cacheKey, CACHE_TTL_SECONDS, JSON.stringify(recommendations));
  
  return recommendations;
}
  3. Pre-compute mixing scores for common pairs
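Scoring every pair of 170k tracks is not practical (about 14.4 billion unique pairs), so a pre-computation pass has to restrict which pairs it scores. A sketch under stated assumptions: it sorts tracks by BPM, scores each track only against its nearest `neighbors` tracks in that order, and keeps the top `k`, producing entries shaped like the schema's `mixing_scores` array. It applies the hand-written heuristic directly rather than the trained model, and `precompute_top_pairs`, `neighbors=50`, and `k=5` are hypothetical choices:

```python
import numpy as np

def precompute_top_pairs(track_ids, bpm, energy, valence, k=5, neighbors=50):
    """Return {track_id: [{'with_track_id': ..., 'score': ...}, ...]} with each track's top-k partners."""
    bpm, energy, valence = (np.asarray(a, dtype=float) for a in (bpm, energy, valence))
    order = np.argsort(bpm, kind='stable')  # track indices sorted by tempo
    n = len(order)
    top_pairs = {}
    for rank, i in enumerate(order):
        # Candidates: the `neighbors` closest tracks on each side in BPM order
        window = order[max(0, rank - neighbors):min(n, rank + neighbors + 1)]
        cand = window[window != i]
        # Same weights as calculate_mixing_score: 40% BPM, 30% energy, 30% valence
        score = (np.maximum(0, 100 - np.abs(bpm[i] - bpm[cand]) * 2) * 0.4
                 + np.maximum(0, 100 - np.abs(energy[i] - energy[cand]) * 100) * 0.3
                 + np.maximum(0, 100 - np.abs(valence[i] - valence[cand]) * 80) * 0.3)
        best = np.argsort(-score, kind='stable')[:k]
        top_pairs[track_ids[i]] = [
            {'with_track_id': track_ids[cand[b]], 'score': float(score[b])} for b in best
        ]
    return top_pairs
```

Each list could then be stored with pymongo, e.g. `db.tracks.update_one({'track_id': tid}, {'$set': {'mixing_scores': pairs}})`, so a lookup for a popular track skips the endpoint entirely.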

Final Architecture & Metrics

System Performance

Tech Stack Summary

ML Pipeline:

Backend:

Frontend:

DevOps:

Key Takeaways

  1. Feature engineering is critical - The mixing score algorithm needed careful weighting of BPM vs energy vs valence
  2. Cost monitoring is essential - Always set up billing alerts and endpoint limits
  3. Caching saves money - Redis reduced API calls by 80%
  4. Real-time ML inference is hard - <500ms latency required careful optimization
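Takeaway 2 can be made concrete with a CloudWatch alarm on AWS's `EstimatedCharges` billing metric. A sketch that only builds the keyword arguments for `put_metric_alarm`; the helper name, alarm name, threshold, and SNS topic are placeholders. Billing metrics are published in `us-east-1` and only appear once billing alerts are enabled in the account's billing preferences:

```python
def billing_alarm_params(threshold_usd, sns_topic_arn, name='djdoc-monthly-spend'):
    """Keyword arguments for CloudWatch put_metric_alarm on estimated AWS charges."""
    return {
        'AlarmName': name,
        'Namespace': 'AWS/Billing',
        'MetricName': 'EstimatedCharges',
        'Dimensions': [{'Name': 'Currency', 'Value': 'USD'}],
        'Statistic': 'Maximum',
        'Period': 21600,  # 6 hours; estimated charges are only updated a few times a day
        'EvaluationPeriods': 1,
        'Threshold': float(threshold_usd),
        'ComparisonOperator': 'GreaterThanThreshold',
        'AlarmActions': [sns_topic_arn],  # e.g. an SNS topic that emails you
    }
```

Usage: `boto3.client('cloudwatch', region_name='us-east-1').put_metric_alarm(**billing_alarm_params(50, topic_arn))`. Because `EstimatedCharges` accumulates over the billing month, the threshold acts as a month-to-date spending limit.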

What's Next

Future enhancements planned: