상세 컨텐츠

본문 제목

Gemma3 기반 LoRA 튜닝으로 주식 예측 모델 만들기

머신러닝/주가 예측

by byoelcardi 2025. 4. 17. 20:40

본문

최근 오픈된 Google Gemma 모델은 경량화된 구조에도 불구하고 상당한 성능을 보여주고 있습니다. 이번 포스트에서는 Gemma 3-7B 모델을 기반으로, 주식 데이터를 활용한 파인튜닝 코드를 소개하고자 합니다. 특히, LoRA(PEFT) 기법을 활용하여 빠르고 가볍게 학습 가능한 구조를 구현하였습니다.

 

학습 데이터 생성

import pandas as pd
import json
import os
from datetime import datetime
import yfinance as yf
import argparse

def load_stock_data(excel_file):
    """
    Excel 파일에서 주식 데이터를 불러옵니다.
    """
    df = pd.read_excel(excel_file)
    # 날짜 열을 인덱스로 사용하는지 확인
    if df.index.name != 'Date' and 'Date' in df.columns:
        df.set_index('Date', inplace=True)
    return df

def create_sliding_window_data(df, window_size=5):
    """
    주어진 데이터프레임에서 슬라이딩 윈도우 방식으로 학습 데이터를 생성합니다.
    window_size일간의 종가를 입력으로, 다음날 종가를 출력으로 설정합니다.
    """
    data = []
    
    # Close 컬럼이 있는지 확인
    if 'Close' not in df.columns:
        raise ValueError("데이터에 'Close' 컬럼이 없습니다.")
    
    # 종가 데이터 추출
    close_prices = df['Close'].to_numpy()
    
    # 데이터 생성 (슬라이딩 윈도우)
    for i in range(len(close_prices) - window_size):
        input_window = close_prices[i:i+window_size]
        target_price = close_prices[i+window_size]
        
        # 입력 형식: "최근 5일간 종가: [가격1, 가격2, 가격3, 가격4, 가격5]"
        # 이 형식은 모델이 이해하기 쉽도록 설계되었습니다
        prompt = f"최근 {window_size}일간 종가: {input_window.tolist()}"
        
        # 출력 형식: "다음날 종가: 가격"
        response = f"다음날 종가: {target_price}"
        
        data.append({
            "prompt": prompt,
            "response": response
        })
    
    return data

def save_to_jsonl(data, output_file):
    """
    데이터를 JSONL 형식으로 저장합니다.
    기존 파일이 있으면 덮어씁니다.
    """
    # 디렉토리가 없으면 생성
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    
    # 파일이 존재하면 덮어쓰기 모드로 열기
    with open(output_file, 'w', encoding='utf-8') as f:
        for item in data:
            f.write(json.dumps(item, ensure_ascii=False) + '\n')
    
    print(f"학습 데이터가 {output_file}에 저장되었습니다. 총 {len(data)}개의 샘플이 생성되었습니다.")

def get_yfinance_data(ticker, start_date, end_date):
    stock = yf.Ticker(ticker)
    data = stock.history(start=start_date, end=end_date)
    return data

def save_to_excel(data, file_name):
    # Remove timezone information from the datetime index
    data_to_save = data.copy()
    data_to_save.index = data_to_save.index.tz_localize(None)
    
    # stock_data 디렉토리에 저장
    output_path = os.path.join("stock_data", file_name)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    data_to_save.to_excel(output_path)
    print(f"Data successfully saved to {output_path}")

def process_ticker(ticker, start_date, end_date):
    """
    단일 티커에 대한 데이터 처리를 수행합니다.
    """
    print(f"\n=== {ticker.upper()} 데이터 처리 시작 ===")
    data = get_yfinance_data(ticker, start_date, end_date)
    save_to_excel(data, f"{ticker}_stock_price_data.xlsx")
    data_dir = "stock_data"
    excel_file = os.path.join(data_dir, f"{ticker}_stock_price_data.xlsx")
    
    # 데이터 로드
    print(f"{excel_file} 파일에서 주식 데이터를 불러오는 중...")
    stock_data = load_stock_data(excel_file)
    
    # 학습 데이터 생성 (5일 윈도우)
    print("학습 데이터 생성 중...")
    training_data = create_sliding_window_data(stock_data, window_size=5)
    
    # JSONL 파일로 저장
    output_file = os.path.join(data_dir, f"{ticker}_stock_price_prediction_train.jsonl")
    save_to_jsonl(training_data, output_file)
    
    # 샘플 데이터 출력
    print("\n샘플 데이터:")
    for i in range(min(3, len(training_data))):
        print(f"입력: {training_data[i]['prompt']}")
        print(f"출력: {training_data[i]['response']}")
        print()
    print(f"=== {ticker.upper()} 데이터 처리 완료 ===\n")

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='주식 데이터를 다운로드하고 학습 데이터를 생성합니다.')
    parser.add_argument('--ticker', type=str, help='처리할 주식 티커 심볼 (예: meta, msft, aapl)')
    parser.add_argument('--all', action='store_true', help='모든 티커에 대해 처리합니다.')
    
    args = parser.parse_args()
    
    # 날짜 설정
    start_date = "2023-01-01"
    end_date = datetime.now().strftime("%Y-%m-%d")
    
    try:
        if args.all:
            # 모든 티커 처리
            tickers = ["meta", "msft", "aapl"]
            for ticker in tickers:
                process_ticker(ticker, start_date, end_date)
        elif args.ticker:
            # 단일 티커 처리
            process_ticker(args.ticker, start_date, end_date)
        else:
            parser.print_help()
            exit(1)
    except Exception as e:
        print(f"오류 발생: {e}")
        exit(1)

 

해당 코드는 create_stock_training_data.py라고 따로 저장을 하고 학습 모델에서 따로 import 진행하였습니다.

모델 로드

model_id = "google/gemma-3-4b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

 

PEFT with LoRA 설정

peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    task_type=TaskType.CAUSAL_LM,
    lora_dropout=0.1,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)
model = get_peft_model(model, peft_config)

 

 

  • LoRA는 attention layer의 일부만 학습하여 전체 학습 효율을 높입니다.
  • target_modules는 Gemma의 attention 구성요소입니다.

학습 인자

training_args = TrainingArguments(
    output_dir=f"./gemma-lora-finetuned-{ticker}",
    per_device_train_batch_size=1,
    num_train_epochs=5,
    learning_rate=1e-4,
    save_strategy="no",  # 체크포인트 저장 안 함
    report_to="none"
)

 

 

 

코드 전문 (학습 코드)

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import get_peft_model, LoraConfig, TaskType
import argparse
import subprocess
import os

def create_training_data(ticker=None):
    """
    create_stock_training_data.py를 실행하여 학습 데이터를 생성합니다.
    """
    script_path = "create_stock_training_data.py"
    print("\n=== 학습 데이터 생성 시작 ===")
    
    try:
        if ticker:
            # 특정 티커에 대해서만 데이터 생성
            subprocess.run(["python", script_path, "--ticker", ticker], check=True)
        else:
            # 모든 티커에 대해 데이터 생성
            subprocess.run(["python", script_path, "--all"], check=True)
        print("=== 학습 데이터 생성 완료 ===\n")
    except subprocess.CalledProcessError as e:
        print(f"데이터 생성 중 오류 발생: {e}")
        raise

def train_model_for_ticker(ticker):
    """
    주어진 티커에 대해 모델을 학습합니다.
    """
    # 학습 데이터 파일 경로 확인
    data_file = f"stock_data/{ticker}_stock_price_prediction_train.jsonl"
    if not os.path.exists(data_file):
        print(f"{data_file} 파일이 없습니다. 데이터를 생성합니다.")
        create_training_data(ticker)
    
    print(f"\n=== {ticker.upper()} 학습 시작 ===")
    
    # 1. 모델 & 토크나이저 불러오기
    model_id = "google/gemma-3-4b-it"  # 또는 "google/gemma-7b"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
    )

    # 2. PEFT (LoRA) 구성
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        task_type=TaskType.CAUSAL_LM,
        lora_dropout=0.1,
        bias="none",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]  # Gemma 모델의 attention 레이어들
    )

    model = get_peft_model(model, peft_config)

    # 3. 학습 데이터셋 불러오기
    dataset = load_dataset("json", data_files={"train": data_file})

    def tokenize(example):
        prompt = example["prompt"]
        response = example["response"]
        
        # Gemma 모델을 위한 챗 형식으로 프롬프트와 응답을 구성
        full_text = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n{response}<end_of_turn>"
        
        # 토큰화 시 max_length를 좀 더 작게 설정 (데이터가 짧기 때문)
        tokenized = tokenizer(full_text, truncation=True, padding="max_length", max_length=512)
        tokenized["labels"] = tokenized["input_ids"]
        return tokenized

    tokenized_dataset = dataset["train"].map(tokenize)

    # 4. 트레이닝 설정
    training_args = TrainingArguments(
        output_dir=f"./gemma-lora-finetuned-{ticker}",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=1,
        num_train_epochs=5,
        learning_rate=1e-4,
        fp16=False,
        logging_steps=1,
        save_strategy="no",  # checkpoint 저장 비활성화
        report_to="none"
    )

    trainer = Trainer(
        model=model,
        train_dataset=tokenized_dataset,
        args=training_args,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
    )

    # 5. 학습 시작
    trainer.train()

    # 6. 모델 저장
    model.save_pretrained(f"./gemma-lora-finetuned-{ticker}")
    tokenizer.save_pretrained(f"./gemma-lora-finetuned-{ticker}")
    
    print(f"=== {ticker.upper()} 학습 완료 ===\n")


if __name__ == "__main__":
    # 명령행 인자 파싱
    tickers = ["meta", "msft", "aapl"]
    for ticker in tickers:
        train_model_for_ticker(ticker)

 

학습 ticker는 meta, msft, aapl로 코드 작업 진행하였고 추가적으로 원하시는 Ticker가 있을 경우에 ticker 리스트에 추가로 기입 해주시면 됩니다.

관련글 더보기