본문 바로가기
AI Theory/DL Basic

HuggingFace custom dataset 저장하는 방법

by climba 2023. 6. 18.

이번에 수업에서 image classification project를 하다보니 colab으로 매 번 개별 이미지 파일들을 dataset으로 만드는게 시간도 오래걸리고, 무엇보다 vscode에서 사용하기 어려울 것 같아 불편하였다. 예전에 프로젝트를 할 때 huggingface dataset에 올려놓고 사용한 적이 있었는데, 그 방법을 다시 상기시킬 겸 적어보려한다.

 

(사실 아래 링크에 매우 자세한 방법이 나와있습니다. 영어가 편하신 분들은 공식 레포를 읽는 것이 더 도움 될 수 있습니다!)

https://huggingface.co/docs/datasets/image_dataset

 

Create an image dataset

There are two methods for creating and sharing an image dataset. This guide will show you how to: You can control access to your dataset by requiring users to share their contact information first. Check out the Gated datasets guide for more information ab

huggingface.co

우선 데이터셋의 구조는 train 안에 총 8개의 label에 해당하는 폴더가 있고, 각 폴더에 이미지들이 모여있는 형태이다.

colab 기준으로 살펴보면, 대충 아래와 같은 형태이다.

 

/content/my_home/MyDrive/math_for_ai_project/math_for_ai_project/train/0/0000.jpg

/content/my_home/MyDrive/math_for_ai_project/math_for_ai_project/train/0/0001.jpg

/content/my_home/MyDrive/math_for_ai_project/math_for_ai_project/train/0/0002.jpg

...

/content/my_home/MyDrive/math_for_ai_project/math_for_ai_project/train/7/0399.jpg

 

from datasets import load_dataset

dataset = (load_dataset("imagefolder", data_dir="/content/drive/MyDrive/math_for_ai_project/math_for_ai_project/train")['train'].train_test_split(train_size=3000, test_size=200))
print(dataset)

pip install로 datasets 라이브러리를 설치한 후 위와 같이 train과 test(valid)를 적당한 비율로 나눠준다.

 

끝난 뒤 dataset을 출력해보면 아래와 같이 나올 것이다.

from datasets import load_dataset
dataset.push_to_hub("climba/image-classification-8class")

그 후 위 코드를 실행시켜주면 되는데 push_to_hub("huggingface 아이디/저장할 경로 이름(새로 지정)")로 실행하면 된다.

다만 실행 시 위와 같은 오류가 뜰 것인데, 아래 코드를 실행해준 다음, huggingface 개인정보(Profile - Settings - Access Tokens)에 있는 token(write token)을 입력해주면 된다.

from huggingface_hub import notebook_login

notebook_login()

그 후 다시 push_to_hub() 코드를 실행해보면, 정상적으로 잘 작동할 것이다. (다만 데이터셋 크기에 따라 시간이 조금 소요된다.)

잘 불러와졌는지는 아래 코드로 확인하면 된다.

from datasets import load_dataset
my_dataset = load_dataset("climba/image-classification-8class")
print(my_dataset)

댓글