ML & DL

[optuna] hyper-parameter 최적화 하기

ssun-g 2022. 2. 24. 19:19
hyper-parameter 최적화를 도와주는 프레임워크.
매 Trial 마다 parameter를 변경하면서 최적의 parameter를 찾는다.

 

MNIST.ipynb
0.07MB

전체 코드는 첨부파일 참조

 

optuna 설치

pip install optuna

 

사용법

  • Library import
    import matplotlib.pyplot as plt
    
    import optuna
    from optuna.trial import TrialState
    
    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader
    
    import sklearn.datasets

 

  • MNIST dataset download
    mnist = sklearn.datasets.fetch_openml('mnist_784', data_home='mnist_784')
  • dataset 정의
    class CustomDataset(Dataset):
        def __init__(self, images, labels):
            self.images = images
            self.labels = labels
    
        def __getitem__(self, index):
            image = self.images[index].reshape((1, 28, 28))
            return image, self.labels[index]
    
        def __len__(self):
            return len(self.images)

 

  • model 정의
    class CNN(nn.Module):
        def __init__(self):
            super(CNN, self).__init__()
            self.conv_1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)  # (28, 28, 32)
            self.relu_1 = nn.ReLU()
            self.pool_1 = nn.MaxPool2d(kernel_size=2, stride=2)  # (14, 14, 32)
    
            self.conv_2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)  # (14, 14, 64)
            self.relu_2 = nn.ReLU()
            self.pool_2 = nn.MaxPool2d(kernel_size=2, stride=2)  # (7, 7, 64)
    
            self.fc_1 = nn.Linear(7 * 7 * 64, 10)
    
        def forward(self, x):
            x = self.conv_1(x)
            x = self.relu_1(x)
            x = self.pool_1(x)
    
            x = self.conv_2(x)
            x = self.relu_2(x)
            x = self.pool_2(x)
            
            x = x.view(x.size(0), -1)
            x = self.fc_1(x)
    
            return x

 

  • model training. 최적화 할 parameter의 범위를 지정하거나 목록으로 제공한다.
    def objective(trial):
        model = CNN().to(device)
        criterion = nn.CrossEntropyLoss().to(device)
        
        # 최적화 할 parameter를 목록으로 제공
        optimizer_name = trial.suggest_categorical("optimizer", ["AdamW", "RMSprop", "SGD"])
        
        # 최적화 할 parameter 범위 지정
        learning_rate = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
        
        optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=learning_rate)
    
        for epoch in range(num_epochs):
            model.train()
            correct_count = 0
            for inputs, labels in train_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
    
                optimizer.zero_grad()
    
                outputs = model(inputs)
                preds = torch.argmax(outputs, dim=-1)
                loss = criterion(outputs, labels)
    
                loss.backward()
                optimizer.step()
    
                correct_count += torch.sum(preds == labels)
                accuracy = correct_count.double() / len(train_loader)
    
            # print(f'[Epoch {epoch+1}/{num_epochs}] acc: {accuracy:.3f}')
    
            trial.report(accuracy, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()
    
        return accuracy
     위 함수는 다음 조건을 반드시 만족해야 한다.
    1. trial을 매개 변수로 받아야 한다.
    2. 평가 지표를 return 해야 한다.

 

  • parameter 최적화
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=50, timeout=600)  # n_trial: 최적화 수행 횟수
    
    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])
    
    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))
    
    print("Best trial:")
    trial = study.best_trial
    
    print("  Value: ", trial.value)
    
    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))

최적화 결과

 

시각화

시각화 함수 종류는 Link에서 확인가능하다.

Jupyter lab을 사용하는 경우 시각화 결과가 제대로 출력되지 않으므로 installation guide를 따른다.
또는 matplotlib를 이용한 시각화도 가능하다.

 

다음은 몇 가지 예시이다.

 

  • 시행 횟수에 따른 변화
    optuna.visualization.plot_optimization_history(study)
    # 또는 optuna.visualization.matplotlib.plot_optimization_history(study)

 

  • parameter 조합에 따른 성능 확인 (색이 진할수록 성능이 좋다)
    optuna.visualization.plot_parallel_coordinate(study)
    # 또는 optuna.visualization.matplotlib.plot_parallel_coordinate(study)

 

  • 성능에 영향을 미치는 요소 확인 (값이 클수록 영향을 많이 준다)
    optuna.visualization.plot_param_importances(study)
    # 또는 optuna.visualization.matplotlib.plot_param_importances(study)