Skip to content

Commit 23238ba

Browse files
Ecursoragent
authored andcommitted
LoRA-in-generation: API, pipeline shift, and Create UI
- api/generate: lora_adapters endpoint, pass loraNameOrPath/loraWeight to generation; defaults steps 65, guidance 4.0 - cdmf_pipeline_ace_step: shift parameter (default 6.0) for scheduler - generate_ace: pass shift 6.0 into pipeline - CreatePanel: LoRA adapter selector and weight - TrainingPanel: copy noting LoRA appears in Create after training - api.ts / types.ts: LoRA types and getLoraAdapters() Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 971cd12 commit 23238ba

7 files changed

Lines changed: 126 additions & 20 deletions

File tree

api/generate.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from flask import Blueprint, jsonify, request, send_file
1313

1414
from cdmf_paths import get_output_dir, get_user_data_dir
15+
from cdmf_tracks import list_lora_adapters
1516

1617
bp = Blueprint("api_generate", __name__)
1718

@@ -103,17 +104,16 @@ def _run_generation(job_id: str) -> None:
103104
duration = 60
104105
# UI may send duration=-1 or 0; clamp to valid range (15–240s)
105106
duration = max(15, min(240, duration))
107+
# Guide: 65 steps + CFG 4.0 for best quality; low CFG reduces artifacts (see community guide).
106108
try:
107-
steps = int(params.get("inferenceSteps") or 55)
109+
steps = int(params.get("inferenceSteps") or 65)
108110
except (TypeError, ValueError):
109-
steps = 55
111+
steps = 65
110112
steps = max(1, min(100, steps))
111-
# Doc recommends 7.0 default; higher helps adherence to caption and reference (see ACE-Step-INFERENCE.md).
112113
try:
113-
guidance_default = 7.0 if src_audio_path else 6.0
114-
guidance_scale = float(params.get("guidanceScale") or guidance_default)
114+
guidance_scale = float(params.get("guidanceScale") or 4.0)
115115
except (TypeError, ValueError):
116-
guidance_scale = 7.0 if src_audio_path else 6.0
116+
guidance_scale = 4.0
117117
try:
118118
seed = int(params.get("seed") or 0)
119119
except (TypeError, ValueError):
@@ -159,6 +159,14 @@ def _run_generation(job_id: str) -> None:
159159
repaint_end = -1.0
160160
# -1 means "end of audio"; generate_track_ace converts to target duration
161161

162+
# LoRA adapter (optional): path or folder name under custom_lora
163+
lora_name_or_path = (params.get("loraNameOrPath") or params.get("lora_name_or_path") or "").strip()
164+
try:
165+
lora_weight = float(params.get("loraWeight") or params.get("lora_weight") or 0.75)
166+
except (TypeError, ValueError):
167+
lora_weight = 0.75
168+
lora_weight = max(0.0, min(2.0, lora_weight))
169+
162170
if src_audio_path:
163171
logging.info("[API generate] Using reference audio: %s (task=%s, audio2audio=%s)", src_audio_path, task, audio2audio_enable)
164172
else:
@@ -193,6 +201,8 @@ def _run_generation(job_id: str) -> None:
193201
repaint_end=repaint_end,
194202
vocal_gain_db=0.0,
195203
instrumental_gain_db=0.0,
204+
lora_name_or_path=lora_name_or_path or None,
205+
lora_weight=lora_weight,
196206
)
197207

198208
wav_path = summary.get("wav_path")
@@ -234,6 +244,17 @@ def _run_generation(job_id: str) -> None:
234244
break
235245

236246

247+
@bp.route("/lora_adapters", methods=["GET"])
248+
def get_lora_adapters():
249+
"""GET /api/generate/lora_adapters — list LoRA adapters (e.g. from Training or custom_lora)."""
250+
try:
251+
adapters = list_lora_adapters()
252+
return jsonify({"adapters": adapters})
253+
except Exception as e:
254+
logging.exception("[API generate] list_lora_adapters failed: %s", e)
255+
return jsonify({"adapters": []})
256+
257+
237258
@bp.route("", methods=["POST"], strict_slashes=False)
238259
@bp.route("/", methods=["POST"], strict_slashes=False)
239260
def create_job():

cdmf_pipeline_ace_step.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -921,6 +921,7 @@ def flowedit_diffusion_process(
921921
n_max=1.0,
922922
n_avg=1,
923923
scheduler_type="euler",
924+
shift: float = 6.0,
924925
):
925926

926927
do_classifier_free_guidance = True
@@ -932,7 +933,7 @@ def flowedit_diffusion_process(
932933

933934
scheduler = FlowMatchEulerDiscreteScheduler(
934935
num_train_timesteps=1000,
935-
shift=3.0,
936+
shift=shift,
936937
)
937938

938939
T_steps = infer_steps
@@ -1111,25 +1112,26 @@ def add_latents_noise(
11111112
noise,
11121113
scheduler_type,
11131114
infer_steps,
1115+
shift: float = 6.0,
11141116
):
11151117

11161118
bsz = gt_latents.shape[0]
11171119
if scheduler_type == "euler":
11181120
scheduler = FlowMatchEulerDiscreteScheduler(
11191121
num_train_timesteps=1000,
1120-
shift=3.0,
1122+
shift=shift,
11211123
sigma_max=sigma_max,
11221124
)
11231125
elif scheduler_type == "heun":
11241126
scheduler = FlowMatchHeunDiscreteScheduler(
11251127
num_train_timesteps=1000,
1126-
shift=3.0,
1128+
shift=shift,
11271129
sigma_max=sigma_max,
11281130
)
11291131
elif scheduler_type == "pingpong":
11301132
scheduler = FlowMatchPingPongScheduler(
11311133
num_train_timesteps=1000,
1132-
shift=3.0,
1134+
shift=shift,
11331135
sigma_max=sigma_max
11341136
)
11351137

@@ -1180,6 +1182,7 @@ def text2music_diffusion_process(
11801182
audio2audio_enable=False,
11811183
ref_audio_strength=0.5,
11821184
ref_latents=None,
1185+
shift: float = 6.0,
11831186
):
11841187

11851188
logger.info(
@@ -1212,17 +1215,17 @@ def text2music_diffusion_process(
12121215
if scheduler_type == "euler":
12131216
scheduler = FlowMatchEulerDiscreteScheduler(
12141217
num_train_timesteps=1000,
1215-
shift=3.0,
1218+
shift=shift,
12161219
)
12171220
elif scheduler_type == "heun":
12181221
scheduler = FlowMatchHeunDiscreteScheduler(
12191222
num_train_timesteps=1000,
1220-
shift=3.0,
1223+
shift=shift,
12211224
)
12221225
elif scheduler_type == "pingpong":
12231226
scheduler = FlowMatchPingPongScheduler(
12241227
num_train_timesteps=1000,
1225-
shift=3.0,
1228+
shift=shift,
12261229
)
12271230

12281231
frame_length = int(duration * 44100 / 512 / 8)
@@ -1400,6 +1403,7 @@ def text2music_diffusion_process(
14001403
noise=target_latents,
14011404
scheduler_type=scheduler_type,
14021405
infer_steps=infer_steps,
1406+
shift=shift,
14031407
)
14041408

14051409
attention_mask = torch.ones(bsz, frame_length, device=self.device, dtype=self.dtype)
@@ -1876,6 +1880,7 @@ def __call__(
18761880
save_path: str = None,
18771881
batch_size: int = 1,
18781882
debug: bool = False,
1883+
shift: float = 6.0,
18791884
):
18801885

18811886
start_time = time.time()
@@ -2029,6 +2034,7 @@ def __call__(
20292034
n_max=edit_n_max,
20302035
n_avg=edit_n_avg,
20312036
scheduler_type=scheduler_type,
2037+
shift=shift,
20322038
)
20332039
else:
20342040
target_latents = self.text2music_diffusion_process(
@@ -2062,6 +2068,7 @@ def __call__(
20622068
audio2audio_enable=audio2audio_enable,
20632069
ref_audio_strength=ref_audio_strength,
20642070
ref_latents=ref_latents,
2071+
shift=shift,
20652072
)
20662073

20672074
end_time = time.time()

generate_ace.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -877,6 +877,7 @@ def _run_ace_text2music(
877877
"batch_size": 1,
878878
"save_path": str(output_path),
879879
"debug": False,
880+
"shift": 6.0,
880881
}
881882

882883
# Wire up reference vs source audio per ACE-Step pipeline:

ui/components/CreatePanel.tsx

Lines changed: 67 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import React, { useState, useEffect, useRef, useCallback } from 'react';
2-
import { Sparkles, ChevronDown, Settings2, Trash2, Music2, Sliders, Dices, Hash, RefreshCw, Plus, Upload, Play, Pause, Info } from 'lucide-react';
2+
import { Sparkles, ChevronDown, Settings2, Trash2, Music2, Sliders, Dices, Hash, RefreshCw, Plus, Upload, Play, Pause, Info, Loader2 } from 'lucide-react';
33
import { GenerationParams, Song } from '../types';
44
import { useAuth } from '../context/AuthContext';
5-
import { generateApi } from '../services/api';
5+
import { generateApi, type LoraAdapter } from '../services/api';
66

77
interface ReferenceTrack {
88
id: string;
@@ -142,12 +142,12 @@ export const CreatePanel: React.FC<CreatePanelProps> = ({ onGenerate, isGenerati
142142
const [duration, setDuration] = useState(-1);
143143
const [batchSize, setBatchSize] = useState(1);
144144
const [bulkCount, setBulkCount] = useState(1); // Number of independent generation jobs to queue
145-
const [guidanceScale, setGuidanceScale] = useState(7.0);
145+
const [guidanceScale, setGuidanceScale] = useState(4.0);
146146
const [randomSeed, setRandomSeed] = useState(true);
147147
const [seed, setSeed] = useState(-1);
148148
const [thinking, setThinking] = useState(false); // Default false for GPU compatibility
149149
const [audioFormat, setAudioFormat] = useState<'mp3' | 'flac'>('mp3');
150-
const [inferenceSteps, setInferenceSteps] = useState(8);
150+
const [inferenceSteps, setInferenceSteps] = useState(65);
151151
const [inferMethod, setInferMethod] = useState<'ode' | 'sde'>('ode');
152152
const [shift, setShift] = useState(3.0);
153153

@@ -171,6 +171,10 @@ export const CreatePanel: React.FC<CreatePanelProps> = ({ onGenerate, isGenerati
171171
const [cfgIntervalStart, setCfgIntervalStart] = useState(0.0);
172172
const [cfgIntervalEnd, setCfgIntervalEnd] = useState(1.0);
173173
const [customTimesteps, setCustomTimesteps] = useState('');
174+
const [loraAdapters, setLoraAdapters] = useState<LoraAdapter[]>([]);
175+
const [loraLoading, setLoraLoading] = useState(false);
176+
const [loraNameOrPath, setLoraNameOrPath] = useState('');
177+
const [loraWeight, setLoraWeight] = useState(0.75);
174178
const [useCotMetas, setUseCotMetas] = useState(true);
175179
const [useCotCaption, setUseCotCaption] = useState(true);
176180
const [useCotLanguage, setUseCotLanguage] = useState(true);
@@ -249,6 +253,17 @@ export const CreatePanel: React.FC<CreatePanelProps> = ({ onGenerate, isGenerati
249253
}
250254
}, [referenceAudioUrl, sourceAudioUrl]);
251255

256+
const fetchLoraAdapters = useCallback(() => {
257+
setLoraLoading(true);
258+
generateApi.getLoraAdapters()
259+
.then((res) => setLoraAdapters(res.adapters || []))
260+
.catch(() => setLoraAdapters([]))
261+
.finally(() => setLoraLoading(false));
262+
}, []);
263+
264+
// Fetch LoRA adapters on mount (Training output + custom_lora)
265+
useEffect(() => { fetchLoraAdapters(); }, [fetchLoraAdapters]);
266+
252267
useEffect(() => {
253268
const handleMouseMove = (e: MouseEvent) => {
254269
if (!isResizing) return;
@@ -629,6 +644,8 @@ export const CreatePanel: React.FC<CreatePanelProps> = ({ onGenerate, isGenerati
629644
cfgIntervalStart,
630645
cfgIntervalEnd,
631646
customTimesteps: customTimesteps.trim() || undefined,
647+
loraNameOrPath: loraNameOrPath.trim() || undefined,
648+
loraWeight,
632649
useCotMetas,
633650
useCotCaption,
634651
useCotLanguage,
@@ -1373,7 +1390,7 @@ export const CreatePanel: React.FC<CreatePanelProps> = ({ onGenerate, isGenerati
13731390
<div className="flex items-center justify-between">
13741391
<span className="inline-flex items-center gap-1.5">
13751392
<label className="text-xs font-medium text-zinc-600 dark:text-zinc-400">Inference Steps</label>
1376-
<InfoTooltip text="Number of denoising steps. Turbo: 1–20 (8 recommended). More steps = better quality, slower." />
1393+
<InfoTooltip text="Number of denoising steps. 65 recommended for quality (low CFG + high steps). Turbo: 8–20." />
13771394
</span>
13781395
<span className="text-xs font-mono text-zinc-900 dark:text-white bg-zinc-100 dark:bg-black/20 px-2 py-0.5 rounded">{inferenceSteps}</span>
13791396
</div>
@@ -1386,7 +1403,7 @@ export const CreatePanel: React.FC<CreatePanelProps> = ({ onGenerate, isGenerati
13861403
onChange={(e) => setInferenceSteps(Number(e.target.value))}
13871404
className="w-full h-2 bg-zinc-200 dark:bg-zinc-700 rounded-lg appearance-none cursor-pointer accent-pink-500"
13881405
/>
1389-
<p className="text-[10px] text-zinc-500">More steps = better quality, slower (8 recommended for turbo)</p>
1406+
<p className="text-[10px] text-zinc-500">65 recommended for quality; more steps = slower</p>
13901407
</div>
13911408

13921409
{/* Guidance Scale */}
@@ -1442,6 +1459,50 @@ export const CreatePanel: React.FC<CreatePanelProps> = ({ onGenerate, isGenerati
14421459
</div>
14431460
</div>
14441461

1462+
{/* LoRA adapter (Training / custom_lora) */}
1463+
<div className="grid grid-cols-2 gap-3">
1464+
<div className="space-y-1.5">
1465+
<span className="inline-flex items-center gap-1.5">
1466+
<label className="text-xs font-medium text-zinc-600 dark:text-zinc-400">LoRA adapter</label>
1467+
<InfoTooltip text="Use a custom LoRA (e.g. from Training). After training, click Refresh to see new adapters." />
1468+
<button
1469+
type="button"
1470+
onClick={fetchLoraAdapters}
1471+
disabled={loraLoading}
1472+
className="p-0.5 rounded hover:bg-zinc-200 dark:hover:bg-zinc-600 disabled:opacity-50"
1473+
title="Refresh LoRA list"
1474+
>
1475+
{loraLoading ? <Loader2 size={12} className="animate-spin" /> : <RefreshCw size={12} />}
1476+
</button>
1477+
</span>
1478+
<select
1479+
value={loraNameOrPath}
1480+
onChange={(e) => setLoraNameOrPath(e.target.value)}
1481+
className="w-full bg-zinc-50 dark:bg-black/20 border border-zinc-200 dark:border-white/10 rounded-lg px-2 py-1.5 text-xs text-zinc-900 dark:text-white focus:outline-none"
1482+
>
1483+
<option value="">None</option>
1484+
{loraAdapters.map((a) => (
1485+
<option key={a.path} value={a.path}>{a.name}</option>
1486+
))}
1487+
</select>
1488+
</div>
1489+
<div className="space-y-1.5">
1490+
<span className="inline-flex items-center gap-1.5">
1491+
<label className="text-xs font-medium text-zinc-600 dark:text-zinc-400">LoRA weight</label>
1492+
<InfoTooltip text="Strength of the LoRA (0–2). 0.75 is a good default; lower = subtler, higher = stronger style." />
1493+
</span>
1494+
<input
1495+
type="number"
1496+
min={0}
1497+
max={2}
1498+
step={0.05}
1499+
value={loraWeight}
1500+
onChange={(e) => setLoraWeight(Number(e.target.value))}
1501+
className="w-full bg-zinc-50 dark:bg-black/20 border border-zinc-200 dark:border-white/10 rounded-lg px-2 py-1.5 text-xs text-zinc-900 dark:text-white focus:outline-none"
1502+
/>
1503+
</div>
1504+
</div>
1505+
14451506
{/* Seed */}
14461507
<div className="space-y-2">
14471508
<div className="flex items-center justify-between">

ui/components/TrainingPanel.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ export const TrainingPanel: React.FC<TrainingPanelProps> = ({ onTracksUpdated: _
250250
<h2 className="text-lg font-semibold">Train Custom LoRA</h2>
251251
</div>
252252
<p className="text-sm text-zinc-500 dark:text-zinc-400 mb-4">
253-
Run LoRA training on your dataset. Dataset folder must be under <code className="bg-zinc-200 dark:bg-zinc-700 px-1 rounded">training_datasets</code>. Use Browse to select a folder.
253+
Run LoRA training on your dataset. Dataset folder must be under <code className="bg-zinc-200 dark:bg-zinc-700 px-1 rounded">training_datasets</code>. Use Browse to select a folder. When training finishes, the LoRA is saved automatically and will appear in <strong>Create → LoRA adapter</strong> (click Refresh there if needed).
254254
</p>
255255

256256
{aceReady === false && aceState !== 'downloading' && (

ui/services/api.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,10 @@ export interface GenerationParams {
264264
completeTrackClasses?: string[];
265265
isFormatCaption?: boolean;
266266
outputDir?: string;
267+
/** LoRA adapter: folder name (from list) or full path. Used for ACE-Step generation. */
268+
loraNameOrPath?: string;
269+
/** LoRA weight 0–2. Default 0.75. */
270+
loraWeight?: number;
267271
}
268272

269273
export interface GenerationJob {
@@ -281,6 +285,12 @@ export interface GenerationJob {
281285
error?: string;
282286
}
283287

288+
export interface LoraAdapter {
289+
name: string;
290+
path: string;
291+
size_bytes?: number | null;
292+
}
293+
284294
export const generateApi = {
285295
startGeneration: (params: GenerationParams, token: string): Promise<GenerationJob> =>
286296
api('/api/generate', { method: 'POST', body: params, token }),
@@ -291,6 +301,10 @@ export const generateApi = {
291301
getHistory: (token: string): Promise<{ jobs: GenerationJob[] }> =>
292302
api('/api/generate/history', { token }),
293303

304+
/** List LoRA adapters (Training output and custom_lora folder). */
305+
getLoraAdapters: (): Promise<{ adapters: LoraAdapter[] }> =>
306+
api('/api/generate/lora_adapters'),
307+
294308
uploadAudio: async (file: File, token: string): Promise<{ url: string; key: string }> => {
295309
const url = `${API_BASE}/api/generate/upload-audio`;
296310
console.log('[API] POST', url);

ui/types.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ export interface GenerationParams {
110110
trackName?: string;
111111
completeTrackClasses?: string[];
112112
isFormatCaption?: boolean;
113+
loraNameOrPath?: string;
114+
loraWeight?: number;
113115
}
114116

115117
export interface PlayerState {

0 commit comments

Comments
 (0)