```python
import os

from knowledge_storm.collaborative_storm.engine import (
    CollaborativeStormLMConfigs,
    RunnerArgument,
    CoStormRunner,
)
from knowledge_storm.lm import LitellmModel
from knowledge_storm.logging_wrapper import LoggingWrapper
from knowledge_storm.rm import BingSearch

# Co-STORM uses a multi-LM configuration: each component of the pipeline
# gets its own language model instance.
gpt_4o_model_name = "gpt-4o"
lm_config: CollaborativeStormLMConfigs = CollaborativeStormLMConfigs()
openai_kwargs = {
    "api_key": os.getenv("OPENAI_API_KEY"),
    "api_provider": "openai",
    "temperature": 1.0,
    "top_p": 0.9,
    "api_base": None,
}
question_answering_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs)
discourse_manage_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs)
utterance_polishing_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=2000, **openai_kwargs)
warmstart_outline_gen_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=500, **openai_kwargs)
question_asking_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=300, **openai_kwargs)
knowledge_base_lm = LitellmModel(model=gpt_4o_model_name, max_tokens=1000, **openai_kwargs)

lm_config.set_question_answering_lm(question_answering_lm)
lm_config.set_discourse_manage_lm(discourse_manage_lm)
lm_config.set_utterance_polishing_lm(utterance_polishing_lm)
lm_config.set_warmstart_outline_gen_lm(warmstart_outline_gen_lm)
lm_config.set_question_asking_lm(question_asking_lm)
lm_config.set_knowledge_base_lm(knowledge_base_lm)

# Pass additional keyword arguments to customize the run;
# see the RunnerArgument class for the full set of options.
topic = input('Topic: ')
runner_argument = RunnerArgument(topic=topic)
logging_wrapper = LoggingWrapper(lm_config)
bing_rm = BingSearch(bing_search_api_key=os.environ.get("BING_SEARCH_API_KEY"),
                     k=runner_argument.retrieve_top_k)
costorm_runner = CoStormRunner(lm_config=lm_config,
                               runner_argument=runner_argument,
                               logging_wrapper=logging_wrapper,
                               rm=bing_rm)

# Warm start the system to build a shared conceptual space with the user.
costorm_runner.warm_start()

# Step through the collaborative discourse. The two calls below can be run
# in any order, as many times as you like.
# To observe the conversation (the system takes the next turn on its own):
conv_turn = costorm_runner.step()
# To inject your own utterance and actively steer the conversation:
costorm_runner.step(user_utterance="YOUR UTTERANCE HERE")

# Reorganize the collected knowledge base and generate the final report.
costorm_runner.knowledge_base.reorganize()
article = costorm_runner.generate_report()
print(article)
```
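`generate_report()` returns the article as a string, so instead of only printing it you can write it to disk; the sketch below assumes a Markdown output file whose name (`report.md`) is just an example.

```python
# Persist the generated report; "report.md" is an arbitrary example filename.
with open("report.md", "w", encoding="utf-8") as f:
    f.write(article)
```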